BEiTのlast_hidden_stateを理解する

https://huggingface.co/transformers/model_doc/beit.html#transformers.BeitModel

を参照すると,以下のようにBEiTから特徴を抽出するコードが書かれている.

from transformers import BeitFeatureExtractor, BeitModel
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
model = BeitModel.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')

inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
last_hidden_states = outputs.last_hidden_state

このコードにあるlast_hidden_statesはBEiTから抽出された特徴であり,これをLSTMやMLPに入力しなおすことによって固有表現抽出器や文章分類器を実装することができる.

こちらのコードでは,size=(640, 480)の画像をtorch.Size([1,3,224,224])に変換し,このtensorをBEiT入力した時に抽出された特徴がlast_hidden_statesになる.last_hidden_statesのサイズは,(1, 197, 768)となっており,それぞれの数字は(batch_size, sequence_length, hidden_size)を示している.

実際に上のコードを実行すると,last_hidden_statesには(1, 197, 768)のtensorが格納されている.このとき1はbatch_sizeを表し,197はsequence_lengthを表し,768はhidden_sizeを表すことはわかるが,BEiTにおいてsequence_lengthは何を表しているのであろうかと疑問に思うことがあるはずだ.

この疑問は論文中の図を参照することによって解消される.

技術論文】画像をパッチ化して扱う Image Transformer で BERT 型の自己教師あり学習 BEiT  を提案、マスクされた画像パッチの連続値での予測 (回帰)ではなく、dVAE  による画像トークナイザで得た画像トークンの分類で事前学習。HuggingFace でモデルが公開。 | ツイレポ

BEiTはご存じのように,画像をバッチ化して扱うimage transformerであり,バッチ化された画像をTokenとして扱っている.ここまで言えば理解できると思うが,sequence_lengthのsequenceはバッチ化された画像(Token)のsequenceを表しているのだ.こう考えると,224×224の画像を16×16のバッチ化すると,バッチ化された画像(Token)の総数は16×16=196となる.「あれ?さっき書かれていたサイズは(1, 197, 768)で違うじゃないか」と思うかもしれないが,BEiTではsequenceの先頭にスペシャルトークン[S]が挿入されているため,196+1=197がsequence_lengthとなるのだ.

自然言語処理で提案されたOriginalのBERTでは,スペシャルトークンとして文頭に[CLS]と文の切れ目を表す[SEP]トークンが挿入されてるが,画像の場合,文章のように切れ目を定義するのが困難なため,文頭を示すスペシャルトークンのみを挿入したのであろう.

BERT,Deep Learning

Posted by vastee