なぜBERTの隠れ層の次元が768なのか?

9月 20, 2022

BERTを用いて文や単語から特徴抽出していると、取り出されたベクトルに出現する数字が、バッチサイズを示しているのか、トークン数を示しているのかなどが分からなくなってしまう。見覚えがあるが何の数字だったか混乱した場合は、「BERT 768」などと検索すれば良いのだが、なるべく覚えていきたいので自分用にメモする。

768: hidden state
・隠れ層の次元

なぜ768なのか?についてTwitterで議論が存在する。この回答によると、TPUの行列乗算ユニットは128の浮動小数点の幅をもっており、このユニットを複数もっている。このため、TPUで行列乗算を行う場合、各次元が128の倍数(あるいはもっと良いのは256の倍数)である行列に対して最も高速になる。らしい。

Inception-v4などの有名論文を数多くもつCristian Szegedy氏の返信

ついでに、BERTの隠れ層以外にも以下のトークン数を紹介しよう。

31090: Vocab size
 ・BERTのボキャブラリに収録されるトークン数。
 ・ボキャブラリファイル(vocab.txt)を独自のものに変更するとここの数字も変化。

BERTの各引数についてわかりやすく解説されているブログがあるため、自身の参考用に転載させていただく。

BERT modelのforwardの引数と出力の関係

transformersのmodeling_bert.pyを眺めるとわかります。

last_hidden_state: 最終層の隠れ層のベクトル(1xtoken数x各tokenのベクトル次元)

pooler_output: 最終層の隠れ層のベクトルの内、最初のtokenのみを取り出してdense+tanh()する操作(最終層のCLSに対応するベクトルの抽出)

hidden_states: 入力のembeddings、最終層の隠れ層のベクトルも含めた、全層の隠れベクトルのリスト(12段なら1xtoken数x各tokenのベクトル次元のtorch.tensorが13個できる)。リストの後ろの要素ほど最終層に近い層のベクトル。forwardの引数にoutput_hidden_states=Trueを入れると出力される

attentions: 各段のtransformerでのforward計算でのattentionのリスト(12段なら1xhead数xtoken数xtoken数のtorch.tensorが12個できる)。forwardの引数にoutput_attentions=Trueを入れると出力される

https://snowman-88888.hatenablog.com/entry/2020/08/21/055414より転載

last_hidden_stateの出力例: (1, n, 768) -> token数(n)に対応する768次元のベクトル。BERTを特徴抽出に用いる際は、この768次元のベクトルを用いることが多い。

また、BERTから抽出された特徴を用いて予測タスクを実行する場合、この768次元に全結合層を接続することにより予測モデルを学習することが多い。多層のNNを接続するモデルも考えられるが、特徴の数の割にパラメータ数が多くなるため、あまり好ましくないらしい。

BERTの入力となる文には、文の切れ目を表す[CLS]や[SEP]トークンが含まれているが、これらのトークンには文全体の特徴が集約されやすいことが知られている。このため、文全体の特徴として[CLS]のベクトルが使われることもある。


Head数に関するメモ

2-3-1 Multi-head Attention

論文では、いくつもの小さなAttentionを並列に繋げると、パフォーマンスが上がったと言われています。それぞれのAttentionをheadと呼ぶので、Multi-head Attentionと呼ばれています。Attention is all you needでは、全体として512次元のtensorが使われていて、この総数はhead数によりません。head=4ならば各headのデータ次元は128になりますし、head=8ならば64次元になります。

acceluniverse.com/blog/developers/2019/08/attention.html

BERT

Posted by vastee