luojunhui 4 ماه پیش
والد
کامیت
df262c3391
1فایلهای تغییر یافته به همراه1 افزوده شده و 1 حذف شده
  1. 1 1
      train_model.py

+ 1 - 1
train_model.py

@@ -46,7 +46,7 @@ class FestivalDataset(Dataset):
     def __init__(self, encodings, labels):
         self.input_ids = encodings["input_ids"]
         self.attention_mask = encodings["attention_mask"]
-        self.labels = torch.tensor(labels)
+        self.labels = torch.tensor(labels, dtype=torch.long)
 
     def __getitem__(self, idx):
         return {