|
@@ -109,7 +109,7 @@ for epoch in range(5):
|
|
|
inputs = {k: v.to(device) for k, v in batch.items()}
|
|
|
outputs = model(**inputs)
|
|
|
logits = outputs.logits.squeeze()
|
|
|
- loss = loss_fn(logits, inputs["labels"].float())
|
|
|
+ loss = loss_fn(logits, inputs["labels"])
|
|
|
val_loss += loss.item()
|
|
|
|
|
|
print(f"Epoch {epoch + 1} | Val Loss: {val_loss / len(val_loader):.4f}")
|