PyTorchの「expected scalar type Half but found Float」の直し方

結論を先に。これはモデルと入力データ(または一部の層やテンソル)の精度(dtype)が混在しているのが原因です。半精度(fp16 / half)に揃えるか、AMPを正しく使えば直ります。

RuntimeError: expected scalar type Half but found Float

原因

model.half() で半精度にしたのに、入力テンソルが float32 のままだった、あるいは AMP(自動混合精度)の外で手動の half テンソルを混ぜた、といった場合に起きます。PyTorch は異なる dtype のテンソル同士を計算できません。

対処

  1. 入力をモデルに合わせる。モデルを half にしたなら入力も half にします。
x = x.half() # 入力を半精度に
# あるいはモデルを単精度に戻す
model = model.float()
  1. 高速化のための混合精度なら、手動 .half() をやめて AMP を使う。計算を autocast の中に入れます。
with torch.cuda.amp.autocast():
out = model(x)
loss = criterion(out, y)
  1. 一部だけ half になっていないか確認する。カスタム層や自前で確保したバッファが float のまま残っていることがあります。

まとめ

  • 原因はモデルとデータの dtype の混在
  • 入力を x.half() でモデルに合わせる、または model.float() で戻す
  • 高速化目的なら手動 half をやめて autocast を使う
  • カスタム層やバッファの dtype も確認する