Weight Sharering (重み共有)の実装方法 [Pytorch]

Pytorchのチュートリアルから、weight sharering(重み共有)の実装方法を紹介する。

今回扱うのは以下のチュートリアルである。

Learning PyTorch with Examples — PyTorch Tutorials 1.10.1+cu102 documentation

ここでの説明について、翻訳すると以下のようになる。

PyTorch: Control Flow + Weight Sharing

As an example of dynamic graphs and weight sharing, we implement a very strange model: a third-fifth order polynomial that on each forward pass chooses a random number between 3 and 5 and uses that many orders, reusing the same weights multiple times to compute the fourth and fifth order.

For this model we can use normal Python flow control to implement the loop, and we can implement weight sharing by simply reusing the same parameter multiple times when defining the forward pass.

We can easily implement this model as a Module subclass:

PyTorch: 制御フロー+重み共有

動的グラフと重みの共有の例として、非常に奇妙なモデルを実装します。3次から5次の多項式は、各フォワードパスで3から5の間の乱数を選び、その数の次数を使用し、4次と5次の計算には同じ重みを複数回再利用します。

このモデルでは、ループを実装するために通常のPythonのフロー制御を使うことができ、フォワードパスを定義するときに同じパラメータを複数回再利用するだけで、重みの共有を実装することができます。

このモデルは、モジュールのサブクラスとして簡単に実装することができます。

重み共有の実装に関して

重み共有をPytorchで実装する場合、上記説明に書いてある通り、forループを用いることで簡単に実装することができる。ここで注意したいのが、モジュール(ここではmiddle_layer)を使いまわした場合、そのモジュールの重みが共有されるという点だ。つまり、forループで呼び出したそれぞれのモジュールは独立ではない。例えば3回ループさせた場合、3つの新しい独立のモジュールがネットワークに追加されるのではなく、1つのモジュールの重みが3回更新されることになる。

重みを共有した場合の実装は以下のようになる。

# -*- coding: utf-8 -*-
import random
import torch

class DynamicNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        コンストラクタ:「middle_linear」を追加
        """
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(D_in, H)
        self.middle_linear = torch.nn.Linear(H, H)
        self.output_linear = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        順伝播の「middle_layer」を0〜3回forループで繰り返します。
        「middle_layer」をループ回数使い回しています。
        
        順伝播を定義するときには、Pythonのループや制御文を利用することができます。
        
        ここでは、計算グラフを定義するときに同じモジュールを何度も再利用しても大丈夫なことを示しています。
        これはLuaTorch(昔のPyTorch)からの改善点で、昔は各モジュールは1回しか使えませんでした。
        """
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(3):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred

N, D_in, H, D_out = 64, 1000, 100, 10

x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

model = DynamicNet(D_in, H, D_out)

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

# ネットワーク可視化のため
y_pred = model(x)

# モデル構造の書き出し
from torchviz import make_dot
image = make_dot(y_pred, params=dict(model.named_parameters()))
image.format = "png"
image.render("sample")

この実装で注意してもらいたのが、

for _ in range(3):
    h_relu = self.middle_linear(h_relu).clamp(min=0)

の部分であり、ここで複数回ループすることによって、複数のmiddle_linearで重みが共有されることとなる。

ループを0回回したとき(通常のモデル)の可視化結果は以下のようになる。

ループを3回回したときの可視化結果は以下のようになる。

これらの結果を見比べると、middle_linear.weightから3本の矢印がのびており、それぞれのモジュールにつながっていることがわかる。

まとめると、Pytorchではforループを用いることで重みの共有が実装できる。

参考

本記事に掲載したコードは以下のブログに掲載されているコードを参考に作成しました。

【PyTorchチュートリアル】LEARNING PYTORCH WITH EXAMPLESの6、nnモジュールでクラス化 | ぱんだクリップ (panda-clip.com)

実際にforループを用いた重み共有が用いられているコード

MLP-Mixer-pytorch/mlp-mixer.py at master · rishikksh20/MLP-Mixer-pytorch · GitHub

for _ in range(depth):
            self.mixer_blocks.append(MixerBlock(dim, self.num_patch, token_dim, channel_dim))

で複数のMLPの重みが共有されている。

PyTorch

Posted by vastee