|
@@ -46,7 +46,7 @@ class FestivalDataset(Dataset):
|
|
|
def __init__(self, encodings, labels):
|
|
|
self.input_ids = encodings["input_ids"]
|
|
|
self.attention_mask = encodings["attention_mask"]
|
|
|
- self.labels = torch.tensor(labels)
|
|
|
+ self.labels = torch.tensor(labels, dtype=torch.long)
|
|
|
|
|
|
def __getitem__(self, idx):
|
|
|
return {
|