|
@@ -5,6 +5,11 @@
|
|
|
import pandas as pd
|
|
|
from sklearn.model_selection import train_test_split
|
|
|
from transformers import BertTokenizer
|
|
|
+from transformers import BertForSequenceClassification, AdamW
|
|
|
+
|
|
|
+# PyTorch core and dataset utilities (Dataset, DataLoader)
|
|
|
+import torch
|
|
|
+from torch.utils.data import Dataset, DataLoader
|
|
|
|
|
|
# 加载数据
|
|
|
df = pd.read_csv("festival_data.csv")
|
|
@@ -15,7 +20,7 @@ labels = df["label"].tolist()
|
|
|
print("开始划分数据集")
|
|
|
train_texts, val_texts, train_labels, val_labels = train_test_split(
|
|
|
texts, labels, test_size=0.2, stratify=labels
|
|
|
-) # 保证类别平衡[1,4](@ref)
|
|
|
+)
|
|
|
|
|
|
# 初始化分词器
|
|
|
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
|
|
@@ -37,10 +42,6 @@ print("开始编码训练集/验证集")
|
|
|
train_encodings = encode_texts(train_texts)
|
|
|
val_encodings = encode_texts(val_texts)
|
|
|
|
|
|
-# 创建PyTorch数据集
|
|
|
-import torch
|
|
|
-from torch.utils.data import Dataset, DataLoader
|
|
|
-
|
|
|
|
|
|
class FestivalDataset(Dataset):
|
|
|
def __init__(self, encodings, labels):
|
|
@@ -66,9 +67,6 @@ val_dataset = FestivalDataset(val_encodings, val_labels)
|
|
|
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
|
|
val_loader = DataLoader(val_dataset, batch_size=64)
|
|
|
|
|
|
-from transformers import BertForSequenceClassification, AdamW
|
|
|
-import torch.nn as nn
|
|
|
-
|
|
|
# 加载预训练模型
|
|
|
print("预加载模型")
|
|
|
model = BertForSequenceClassification.from_pretrained(
|