TypeError: iteration over a 0-d tensor

.item()で取り出されたtensorの次元が0のために発生したエラーのため、下記の9から12行目のようにif c.ndim == 0で条件分岐させて対応。

with torch.no_grad():
 for data in test_loader:
  inputs, labels = data
  outputs = net(inputs)
  _, predicted = torch.max(outputs, 1)
  c = (predicted == labels).squeeze()
  for i in range(len(labels)):
   label = labels[i]
   if c.ndim == 0:
    class_correct[label.item()] += c.item()
   else:
    class_correct[label.item()] += c[i].item()
  class_total[label.item()] += 1

Uncategorized

Posted by vastee