luojunhui 4 months ago
parent
commit
26c242c0e6
1 changed file with 6 additions and 8 deletions
  1. 6 8
      train_model.py

+ 6 - 8
train_model.py

@@ -5,6 +5,11 @@
 import pandas as pd
 from sklearn.model_selection import train_test_split
 from transformers import BertTokenizer
+from transformers import BertForSequenceClassification, AdamW
+
+# 创建PyTorch数据集
+import torch
+from torch.utils.data import Dataset, DataLoader
 
 # 加载数据
 df = pd.read_csv("festival_data.csv")
@@ -15,7 +20,7 @@ labels = df["label"].tolist()
 print("开始划分数据集")
 train_texts, val_texts, train_labels, val_labels = train_test_split(
     texts, labels, test_size=0.2, stratify=labels
-)  # 保证类别平衡[1,4](@ref)
+)
 
 # 初始化分词器
 tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
@@ -37,10 +42,6 @@ print("开始编码训练集/验证集")
 train_encodings = encode_texts(train_texts)
 val_encodings = encode_texts(val_texts)
 
-# 创建PyTorch数据集
-import torch
-from torch.utils.data import Dataset, DataLoader
-
 
 class FestivalDataset(Dataset):
     def __init__(self, encodings, labels):
@@ -66,9 +67,6 @@ val_dataset = FestivalDataset(val_encodings, val_labels)
 train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
 val_loader = DataLoader(val_dataset, batch_size=64)
 
-from transformers import BertForSequenceClassification, AdamW
-import torch.nn as nn
-
 # 加载预训练模型
 print("预加载模型")
 model = BertForSequenceClassification.from_pretrained(