luojunhui 7 hónapja
szülő
commit
df262c3391
1 módosított fájl, 1 hozzáadás és 1 törlés
  1. 1 1
      train_model.py

+ 1 - 1
train_model.py

@@ -46,7 +46,7 @@ class FestivalDataset(Dataset):
     def __init__(self, encodings, labels):
         self.input_ids = encodings["input_ids"]
         self.attention_mask = encodings["attention_mask"]
-        self.labels = torch.tensor(labels)
+        self.labels = torch.tensor(labels, dtype=torch.long)
 
     def __getitem__(self, idx):
         return {