BEiTのlast_hidden_stateを理解する
https://huggingface.co/transformers/model_doc/beit.html#transformers.BeitModel
を参照すると,以下のようにBEiTから特徴を抽出するコードが書かれている.
from transformers import BeitFeatureExtractor, BeitModel from PIL import Image import requests url = 'http://images.cocodataset.org/val2017/000000039769.jpg' image = Image.open(requests.get(url, stream=True).raw) feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k') model = BeitModel.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k') inputs = feature_extractor(images=image, return_tensors="pt") outputs = model(**inputs) last_hidden_states = outputs.last_hidden_state
このコードにあるlast_hidden_statesはBEiTから抽出された特徴であり,これをLSTMやMLPに入力しなおすことによって固有表現抽出器や文章分類器を実装することができる.
こちらのコードでは,size=(640, 480)の画像をtorch.Size([1,3,224,224])に変換し,このtensorをBEiT入力した時に抽出された特徴がlast_hidden_statesになる.last_hidden_statesのサイズは,(1, 197, 768)となっており,それぞれの数字は(batch_size, sequence_length, hidden_size)を示している.
実際に上のコードを実行すると,last_hidden_statesには(1, 197, 768)のtensorが格納されている.このとき1はbatch_sizeを表し,197はsequence_lengthを表し,768はhidden_sizeを表すことはわかるが,BEiTにおいてsequence_lengthは何を表しているのであろうかと疑問に思うことがあるはずだ.
この疑問は論文中の図を参照することによって解消される.
BEiTはご存じのように,画像をバッチ化して扱うimage transformerであり,バッチ化された画像をTokenとして扱っている.ここまで言えば理解できると思うが,sequence_lengthのsequenceはバッチ化された画像(Token)のsequenceを表しているのだ.こう考えると,224×224の画像を16×16のバッチ化すると,バッチ化された画像(Token)の総数は16×16=196となる.「あれ?さっき書かれていたサイズは(1, 197, 768)で違うじゃないか」と思うかもしれないが,BEiTではsequenceの先頭にスペシャルトークン[S]が挿入されているため,196+1=197がsequence_lengthとなるのだ.
自然言語処理で提案されたOriginalのBERTでは,スペシャルトークンとして文頭に[CLS]と文の切れ目を表す[SEP]トークンが挿入されてるが,画像の場合,文章のように切れ目を定義するのが困難なため,文頭を示すスペシャルトークンのみを挿入したのであろう.
ディスカッション
コメント一覧
まだ、コメントがありません