luojunhui 7 meses atrás
pai
commit
df262c3391
1 arquivos alterados com 1 adições e 1 exclusões
  1. 1 1
      train_model.py

+ 1 - 1
train_model.py

@@ -46,7 +46,7 @@ class FestivalDataset(Dataset):
     def __init__(self, encodings, labels):
         self.input_ids = encodings["input_ids"]
         self.attention_mask = encodings["attention_mask"]
-        self.labels = torch.tensor(labels)
+        self.labels = torch.tensor(labels, dtype=torch.long)
 
     def __getitem__(self, idx):
         return {