INFO:root:Loading Dataset...
>> train dataset: 8
>> train_val dataset: 8
>> test dataset: 9
INFO:root:Loading Network...
INFO:root:optimizer Done
INFO:root:Beginning Epoch 00, lr=0.001
Traceback (most recent call last):
  File "train_ggcnn.py", line 323, in <module>
    run()
  File "train_ggcnn.py", line 273, in run
    train_results = train(epoch, net, device, train_data, optimizer)
  File "train_ggcnn.py", line 150, in train
    lossd = focal_loss(net, x.to(device), y[0].to(device), y[1].to(device), y[2].to(device), y[3].to(device))
  File "/root/ggcnn1/models/loss.py", line 65, in focal_loss
    loss_pos = bin_focal_loss(pred_pos, y_pos, alpha=0.9) * 10
  File "/root/ggcnn1/models/loss.py", line 40, in bin_focal_loss
    zeros_loc = torch.where(target == 0)
TypeError: where() missing 2 required positional argument: "input", "other"
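The error says this PyTorch build requires `torch.where(condition, input, other)`, so it rejects the one-argument call `torch.where(target == 0)` at `loss.py` line 40. Recent PyTorch releases accept the one-argument form, so upgrading is one fix. Another is to avoid that form entirely. Below is a minimal sketch, assuming `target` is a tensor as in `bin_focal_loss`. The helper name `where_indices` is hypothetical. It builds the same tuple of per-dimension index tensors from `nonzero()` and `unbind()`.

```python
import torch


def where_indices(condition: torch.Tensor):
    """Hypothetical helper: mimic the one-argument torch.where(condition).

    Returns a tuple of 1-D index tensors, one per dimension of `condition`.
    """
    # nonzero() returns an (N, ndim) tensor of coordinates of True/nonzero
    # entries; unbind(1) splits its columns into one index tensor per dim.
    return condition.nonzero().unbind(1)


# Usage, mirroring the failing line in bin_focal_loss:
target = torch.tensor([[0, 1], [1, 0]])
zeros_loc = where_indices(target == 0)
```

The result can index `target` directly (`target[zeros_loc]`), just like the output of the one-argument `torch.where`.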