torch.saveとstate_dictの違い

Pytorchで学習したモデルを保存する時は

torch.save(model.state_dict(), model_path)

でモデルを保存することが推奨される.

state_dictがsaveで直接保存するのとの違いは,そのファイルサイズである.state_dictで保存するときはネットワーク構造や各レイヤーの引数(入出力チャンネル数やカーネルサイズなど)などの無駄な情報を保存せずに済む.

現在の深層学習フレームワークの主流は,define-by-runなので,ネットワーク構造などの情報はスクリプトに書いてあるはずなので,わざわざその全てを保存する必要はない.

実際pytorchの公式でもstate_dictで保存することが推奨されている.

https://pytorch.org/docs/master/notes/serialization.html

参考までにstate_dictを用いたモデルの保存と読み込み方を紹介

保存

model_path = 'model.pth'
torch.save(model.state_dict(), model_path)

読込

model_path = 'model.pth'
model.load_state_dict(torch.load(model_path))

PyTorch

Posted by vastee