PyTorchの「expected scalar type Half but found Float」の直し方
結論を先に。これはモデルと入力データ(または一部の層やテンソル)の精度(dtype)が混在しているのが原因です。半精度(fp16 / half)に揃えるか、AMPを正しく使えば直ります。
RuntimeError: expected scalar type Half but found Float原因
model.half() で半精度にしたのに、入力テンソルが float32 のままだった、あるいは AMP(自動混合精度)の外で手動の half テンソルを混ぜた、といった場合に起きます。PyTorch は異なる dtype のテンソル同士を計算できません。
対処
- 入力をモデルに合わせる。モデルを half にしたなら入力も half にします。
x = x.half() # 入力を半精度に# あるいはモデルを単精度に戻すmodel = model.float()- 高速化のための混合精度なら、手動
.half()をやめて AMP を使う。計算を autocast の中に入れます。
with torch.cuda.amp.autocast(): out = model(x) loss = criterion(out, y)- 一部だけ half になっていないか確認する。カスタム層や自前で確保したバッファが float のまま残っていることがあります。
まとめ
- 原因はモデルとデータの dtype の混在
- 入力を
x.half()でモデルに合わせる、またはmodel.float()で戻す - 高速化目的なら手動 half をやめて
autocastを使う - カスタム層やバッファの dtype も確認する