luojunhui 4 months ago
parent
commit
3756471ab0
1 changed files with 1 additions and 1 deletions
  1. 1 1
      train_model.py

+ 1 - 1
train_model.py

@@ -109,7 +109,7 @@ for epoch in range(5):
             inputs = {k: v.to(device) for k, v in batch.items()}
             outputs = model(**inputs)
             logits = outputs.logits.squeeze()
-            loss = loss_fn(logits, inputs["labels"].float())
+            loss = loss_fn(logits, inputs["labels"])
             val_loss += loss.item()
 
     print(f"Epoch {epoch + 1} | Val Loss: {val_loss / len(val_loader):.4f}")