torch.saveとstate_dictの違い
Pytorchで学習したモデルを保存する時は
torch.save(model.state_dict(), model_path)
でモデルを保存することが推奨される.
state_dictがsaveで直接保存するのとの違いは,そのファイルサイズである.state_dictで保存するときはネットワーク構造や各レイヤーの引数(入出力チャンネル数やカーネルサイズなど)などの無駄な情報を保存せずに済む.
現在の深層学習フレームワークの主流は,define-by-runなので,ネットワーク構造などの情報はスクリプトに書いてあるはずなので,わざわざその全てを保存する必要はない.
実際pytorchの公式でもstate_dictで保存することが推奨されている.
https://pytorch.org/docs/master/notes/serialization.html
参考までにstate_dictを用いたモデルの保存と読み込み方を紹介
保存
model_path = 'model.pth' torch.save(model.state_dict(), model_path)
読込
model_path = 'model.pth' model.load_state_dict(torch.load(model_path))
ディスカッション
コメント一覧
まだ、コメントがありません