luojunhui hai 7 meses
pai
achega
df262c3391
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      train_model.py

+ 1 - 1
train_model.py

@@ -46,7 +46,7 @@ class FestivalDataset(Dataset):
     def __init__(self, encodings, labels):
         self.input_ids = encodings["input_ids"]
         self.attention_mask = encodings["attention_mask"]
-        self.labels = torch.tensor(labels)
+        self.labels = torch.tensor(labels, dtype=torch.long)
 
     def __getitem__(self, idx):
         return {