Pytorch v0.4のコードをv0.3で動かす際には.dataに注意[Pytorch]
Pytorchのコードを見ているとミニバッチごとのlossやaccuracyを計算する際、.dataを用いて値を取り出されることが頻繁にある。
よくある例:
for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE): inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE] dis_opt.zero_grad() out = discriminator.batchClassify(inp) loss_fn = nn.BCELoss() loss = loss_fn(out, target) loss.backward() dis_opt.step() total_loss += loss.data # .dataを使ってlossから値を取り出しtotal_lossに蓄積 total_acc += torch.sum((out>0.5)==(target>0.5)).data # .dataを使ってaccuracyを取り出しtotal_accに蓄積
.dataを用いて値を取り出す上のコードは、pytorch v0.4で動かすならエラーがでない。
しかし、上のコードをpytorch v0.3で動かすと下のエラーが発生する。
# pytorch v0.3で動かした際のエラー TypeError: div_ received an invalid combination of arguments - got (float), but expected one of: * (int value) didn't match because some of the arguments have invalid types: (!float!) * (torch.cuda.ByteTensor other) didn't match because some of the arguments have invalid types: (!float!)
なぜv0.4で動いていたコードがv0.3で動かなくなってしまうのか?
その原因は、pytorch v0.3でVariableの型の行列から値を取り出す際に.dataを使用するとデータの型がVariableからTensorに変わってしまうからである。
簡単な例で問題をみていこう。
# Pytorch v0.4で動かすと問題ないのだが、v0.3で動かすとエラーが発生するコード # 本コードではv0.3で動かしたことを想定 import torch from torch.autograd import Variable x = Variable(torch.Tensor([1,2,3])) # ListからTensorに変換し、更にTensorをVariableに変換 y = Variable(torch.Tensor([4,5,6])).data # .dataを使ってVariableから値を取り出す。 # 返り値を表示。xにはVariableが格納されていることがわかる。 x >> Variable containing: 1 2 3 [torch.FloatTensor of size 3] # 返り値を表示。yはVariableではなくTensorが格納されてしまった。 y >> 4 5 6 [torch.FloatTensor of size 3] # TensorとVariableを加算するとエラーが発生 x+y >> Traceback (most recent call last): File "<ipython-input-188-259706549f3d>", line 1, in <module> x+y RuntimeError: add() received an invalid combination of arguments - got (torch.FloatTensor), but expected one of: * (float other, float alpha) * (Variable other, float alpha)
上の例のようにpytorch v0.3では、.dataで値を取り出すと型違いによるエラーが発生してしまう為、v0.4で動いていたコードが動かなくなる。
v0.4ではVariableとTensorが統合された為、型違いによるエラーが発生しないようだ。
v0.4公式がまとめた変更点(英語)
v0.4の変更点が日本語でまとめられた記事
ちなみに、v0.4の公式ドキュメントでは.dataを使うことはunsafeだと言っており、代わりに.detachを使うことが推奨されている。
What about .data ? のセクションに .dataを使うことの危険性が述べられている。
.dataで値を取り出した場合、xに対する変更がautogradで追跡できない為、危険視されているようだ。
最後に.detachを使用して書いた、v0.3とv0.4ともにエラーが発生しないコードを載せる。
# v0.3とv0.4ともにエラーが発生しないコード import torch from torch.autograd import Variable x = Variable(torch.Tensor([1,2,3])) # ListからTensorに変換し、更にTensorをVariableに変換 z = Variable(torch.Tensor([7,8,9])).detach() z >> Variable containing: 7 8 9 [torch.FloatTensor of size 3] # v0.3とv0.4ともにエラー無しで加算が行える x+z >> Variable containing: 8 10 12 [torch.FloatTensor of size 3]
追記
あまり良い方法ではないのだが、lossやaccuracyなど1×1の値を取り出す場合には、.data[0]を使う手もある。
.data[0]は、下のIrfan_Buluさんの回答を見て知った。
ディスカッション
コメント一覧
まだ、コメントがありません