luojunhui 7 月之前
父节点
当前提交
3756471ab0
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      train_model.py

+ 1 - 1
train_model.py

@@ -109,7 +109,7 @@ for epoch in range(5):
             inputs = {k: v.to(device) for k, v in batch.items()}
             inputs = {k: v.to(device) for k, v in batch.items()}
             outputs = model(**inputs)
             outputs = model(**inputs)
             logits = outputs.logits.squeeze()
             logits = outputs.logits.squeeze()
-            loss = loss_fn(logits, inputs["labels"].float())
+            loss = loss_fn(logits, inputs["labels"])
             val_loss += loss.item()
             val_loss += loss.item()
 
 
     print(f"Epoch {epoch + 1} | Val Loss: {val_loss / len(val_loader):.4f}")
     print(f"Epoch {epoch + 1} | Val Loss: {val_loss / len(val_loader):.4f}")