luojunhui 7 ay önce
ebeveyn
işleme
3756471ab0
1 değiştirilmiş dosya ile 1 ekleme ve 1 silme
  1. 1 1
      train_model.py

+ 1 - 1
train_model.py

@@ -109,7 +109,7 @@ for epoch in range(5):
             inputs = {k: v.to(device) for k, v in batch.items()}
             outputs = model(**inputs)
             logits = outputs.logits.squeeze()
-            loss = loss_fn(logits, inputs["labels"].float())
+            loss = loss_fn(logits, inputs["labels"])
             val_loss += loss.item()
 
     print(f"Epoch {epoch + 1} | Val Loss: {val_loss / len(val_loader):.4f}")