INFO:root:Loading Dataset...
>> train dataset: 8
>> train_val dataset: 8
>> test dataset: 9
INFO:root:Loading Network...
INFO:root:optimizer Done
INFO:root:Beginning Epoch 00, lr=0.001
Traceback (most recent call last):
  File "train_ggcnn.py", line 323, in <module>
    run()
  File "train_ggcnn.py", line 273, in run
    train_results = train(epoch, net, device, train_data, optimizer)
  File "train_ggcnn.py", line 150, in train
    lossd = focal_loss(net, x.to(device), y[0].to(device), y[1].to(device), y[2].to(device), y[3].to(device))
  File "/root/ggcnn1/models/loss.py", line 65, in focal_loss
    loss_pos = bin_focal_loss(pred_pos, y_pos, alpha=0.9) * 10
  File "/root/ggcnn1/models/loss.py", line 40, in bin_focal_loss
    zeros_loc = torch.where(target == 0)
TypeError: where() missing 2 required positional argument: "input", "other"
Hi blogger, could you tell me how to fix this error?
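A possible cause, offered as a guess rather than a confirmed diagnosis: the message suggests the installed PyTorch only accepts the three-argument form `torch.where(condition, input, other)`, so the one-argument call in `bin_focal_loss` fails. In releases that support it, `torch.where(condition)` is documented as identical to `torch.nonzero(condition, as_tuple=True)`. The sketch below tries the one-argument form first and otherwise builds the same tuple of index tensors from `torch.nonzero`. The helper name `zeros_where` and the sample tensor are hypothetical and not part of the ggcnn code.

```python
import torch


def zeros_where(target):
    """Return a tuple of index tensors (one per dimension) where target == 0.

    Hypothetical helper, intended as a stand-in for torch.where(target == 0).
    """
    cond = target == 0
    try:
        # PyTorch releases that support the one-argument form.
        return torch.where(cond)
    except TypeError:
        # Fallback for releases without it: torch.nonzero returns an
        # (N, ndim) index matrix; transposing and unpacking the rows
        # gives one index tensor per dimension.
        return tuple(torch.nonzero(cond).t())


# Made-up sample data with zeros at positions (0, 0) and (1, 1).
target = torch.tensor([[0.0, 1.0], [2.0, 0.0]])
rows, cols = zeros_where(target)
print(rows.tolist(), cols.tolist())
```

If this guess is right, upgrading PyTorch to a release that accepts `torch.where(condition)` should also work without touching `loss.py`.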