Transformerはn-steps先をどのように推論しているのか?

TransformerのPytorch実装では、学習時に入力をエンコードして、出力を生成するためにnステップ反復し続けるのではなく、ステップごとに入力とターゲットのテンソルを渡す必要がある。では、推論時にTransformerはどうやってターゲットを推論するのであろうか?

その答えを知るのに最も良い方法は、The annotated Transformerの実装をみることである。特に、run_epochの中のgreedy_decodeという関数が参考になるはずだ。

下記にgreedy_decodeのコードを載せる

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type_as(src.data)
    for i in range(max_len-1):
        out = model.decode(memory, src_mask, 
                           Variable(ys), 
                           Variable(subsequent_mask(ys.size(1))
                                    .type_as(src.data)))
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim = 1)
        next_word = next_word.data[0]
        ys = torch.cat([ys, 
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
    return ys

この関数をみると、model.encode⇒model.decode⇒model.generatorの順で推論が行われていることがわかる。そして、model.decodeに入力されているのはmodel.encodeの出力のmemoryとなるため、推論時にはエンコーダの入力がデコーダに入力されていることになる。

まとめると、Transformerは、

学習時には、ソーストークンとターゲットトークンの両方を Transformerに渡す。これは、Teacher Forcingを使用して LSTM または GRU で行うことと同様である。また、Transformer デコーダーでは、現在および将来のトークンを先読みしてしまうことを回避するために、マスキングを適用する必要がある。

推論時には、ターゲットトークンがない (それが予測しようとしているものであるため)。この場合、最初のステップでのデコーダー入力はエンコーダの入力になり、最初のトークンを予測。次に、次のタイムステップの入力を準備して、予測を前のタイムステップ入力に追加し、2 番目のトークンの予測を取得。このように各タイムステップで、過去の位置の計算を繰り返していることになる。実際の実装では、これらの状態はタイムステップごとに再計算される代わりにキャッシュされる。

また、model.generatorは推論時にしか登場しないが、これは学習済みのモデルにMLPとsoftmaxで構成されたヘッドが付随している形となる。

参考

本解説は下記QAを基に作成しました

https://datascience.stackexchange.com/questions/90441/how-does-the-transformer-predict-n-steps-into-the-future?utm_source=pocket_reader

Transformer

Posted by vastee