PyTorchの「mat1 and mat2 shapes cannot be multiplied」の直し方
全結合層を通すと、こう止まることがある。
RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x128 and 256x10)これは行列の掛け算ができない、という意味だ。nn.Linear の入力サイズ(in_features)と、実際に流し込んでいるテンソルの最後の次元が合っていない。上の例なら、入力の最終次元は128なのに、Linearは256を期待している。
まず形を確認する
エラーの数字を読む。64x128 の128が実際の入力の特徴次元、256x10 の256が層の期待値だ。ここを一致させる。
print(x.shape) # 例: torch.Size([64, 128])in_features を入力に合わせる
self.fc = nn.Linear(128, 10) # 256 ではなく、入力の最終次元 128 に合わせるCNNの後ろにLinearを置くときはflattenを忘れない
畳み込みの出力は多次元なので、平坦化してからLinearに入れる。サイズはチャネル×高さ×幅で決まる。
x = x.flatten(1) # (N, C, H, W) -> (N, C*H*W)print(x.shape[1]) # この値を Linear の in_features にするまとめ
- 原因はLinearのin_featuresと入力の最終次元の不一致
- エラーの数字(NxM と PxQ)でMとPを合わせる
- 入力 shape を print して確認する
- CNNの後は flatten してから次元を数える