Pytorch v0.4のコードをv0.3で動かす際には.dataに注意[Pytorch]

7月 27, 2020

Pytorchのコードを見ているとミニバッチごとのlossやaccuracyを計算する際、.dataを用いて値を取り出されることが頻繁にある。

よくある例:

for i in
range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):
inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]
dis_opt.zero_grad()
out = discriminator.batchClassify(inp)
loss_fn = nn.BCELoss()
loss = loss_fn(out, target)
loss.backward()
dis_opt.step()
total_loss += loss.data # .dataを使ってlossから値を取り出しtotal_lossに蓄積
total_acc += torch.sum((out>0.5)==(target>0.5)).data # .dataを使ってaccuracyを取り出しtotal_accに蓄積

.dataを用いて値を取り出す上のコードは、pytorch v0.4で動かすならエラーがでない。

しかし、上のコードをpytorch v0.3で動かすと下のエラーが発生する。

# pytorch v0.3で動かした際のエラー
TypeError: div_ received an invalid combination of arguments - got (float), but expected one of:
* (int value)
didn't match because some of the arguments have invalid types: (!float!)
* (torch.cuda.ByteTensor other)
didn't match because some of the arguments have invalid types: (!float!)

なぜv0.4で動いていたコードがv0.3で動かなくなってしまうのか?

その原因は、pytorch v0.3でVariableの型の行列から値を取り出す際に.dataを使用するとデータの型がVariableからTensorに変わってしまうからである。

簡単な例で問題をみていこう。

# Pytorch v0.4で動かすと問題ないのだが、v0.3で動かすとエラーが発生するコード
# 本コードではv0.3で動かしたことを想定
import torch
from torch.autograd import Variable
x = Variable(torch.Tensor([1,2,3])) # ListからTensorに変換し、更にTensorをVariableに変換
y = Variable(torch.Tensor([4,5,6])).data # .dataを使ってVariableから値を取り出す。
# 返り値を表示。xにはVariableが格納されていることがわかる。
x
>>
Variable containing:
1
2
3
[torch.FloatTensor of size 3]
# 返り値を表示。yはVariableではなくTensorが格納されてしまった。
y
>>
4
5
6
[torch.FloatTensor of size 3]
# TensorとVariableを加算するとエラーが発生
x+y
>>
Traceback (most recent call last):
File "<ipython-input-188-259706549f3d>", line 1, in <module>
x+y
RuntimeError: add() received an invalid combination of arguments - got (torch.FloatTensor), but expected one of:
* (float other, float alpha)
* (Variable other, float alpha)

上の例のようにpytorch v0.3では、.dataで値を取り出すと型違いによるエラーが発生してしまう為、v0.4で動いていたコードが動かなくなる。

v0.4ではVariableとTensorが統合された為、型違いによるエラーが発生しないようだ。

v0.4公式がまとめた変更点(英語)

pytorch.org

v0.4の変更点が日本語でまとめられた記事

qiita.com

ちなみに、v0.4の公式ドキュメントでは.dataを使うことはunsafeだと言っており、代わりに.detachを使うことが推奨されている。

pytorch.org

What about .data ? のセクションに .dataを使うことの危険性が述べられている。

.dataで値を取り出した場合、xに対する変更がautogradで追跡できない為、危険視されているようだ。

最後に.detachを使用して書いた、v0.3とv0.4ともにエラーが発生しないコードを載せる。

# v0.3とv0.4ともにエラーが発生しないコード
import torch
from torch.autograd import Variable
x = Variable(torch.Tensor([1,2,3])) # ListからTensorに変換し、更にTensorをVariableに変換
z = Variable(torch.Tensor([7,8,9])).detach()
z
>>
Variable containing:
7
8
9
[torch.FloatTensor of size 3]
# v0.3とv0.4ともにエラー無しで加算が行える
x+z
>>
Variable containing:
8
10
12
[torch.FloatTensor of size 3]

追記

あまり良い方法ではないのだが、lossやaccuracyなど1×1の値を取り出す場合には、.data[0]を使う手もある。

.data[0]は、下のIrfan_Buluさんの回答を見て知った。

discuss.pytorch.org

PyTorch

Posted by vastee