luojunhui 4 months ago
Commit d96da7b88f
1 file changed, 24 insertions, 18 deletions

train_model.py  +24 -18

@@ -1,35 +1,39 @@
 """
 @author: luojunhui
 """
+
 import pandas as pd
 from sklearn.model_selection import train_test_split
 from transformers import BertTokenizer
 
 # Load the data
 df = pd.read_csv("festival_data.csv")
-texts = df['text'].tolist()
-labels = df['label'].tolist()
+texts = df["text"].tolist()
+labels = df["label"].tolist()
 
 # Split the dataset
+print("Splitting the dataset")
 train_texts, val_texts, train_labels, val_labels = train_test_split(
-    texts, labels, test_size=0.2, stratify=labels)  # stratify keeps the classes balanced
+    texts, labels, test_size=0.2, stratify=labels
+)  # stratify keeps the classes balanced
 
 # Initialize the tokenizer
-tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
+tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
 
 
 # Text encoding helper
 def encode_texts(text_list, max_len=32):
     return tokenizer(
         text_list,
-        padding='max_length',
+        padding="max_length",
         truncation=True,
         max_length=max_len,
-        return_tensors='pt'
+        return_tensors="pt",
     )
 
 
 # Encode the train/validation sets
+print("Encoding train/validation sets")
 train_encodings = encode_texts(train_texts)
 val_encodings = encode_texts(val_texts)
 
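A quick sanity check, not part of this commit, can confirm that stratify=labels keeps the label ratio nearly identical across the two splits; a minimal sketch using the variables defined above:

    from collections import Counter

    # Label distributions should be nearly proportional between train and val
    print("train:", Counter(train_labels))
    print("val:", Counter(val_labels))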
@@ -40,15 +44,15 @@ from torch.utils.data import Dataset, DataLoader
 
 class FestivalDataset(Dataset):
     def __init__(self, encodings, labels):
-        self.input_ids = encodings['input_ids']
-        self.attention_mask = encodings['attention_mask']
+        self.input_ids = encodings["input_ids"]
+        self.attention_mask = encodings["attention_mask"]
         self.labels = torch.tensor(labels)
 
     def __getitem__(self, idx):
         return {
-            'input_ids': self.input_ids[idx],
-            'attention_mask': self.attention_mask[idx],
-            'labels': self.labels[idx]
+            "input_ids": self.input_ids[idx],
+            "attention_mask": self.attention_mask[idx],
+            "labels": self.labels[idx],
         }
 
     def __len__(self):
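The train_loader and val_loader used in the hunks below are presumably built from FestivalDataset in the lines elided here; a minimal sketch of that wiring, where batch_size=16 is an assumed value rather than one taken from this commit:

    from torch.utils.data import DataLoader

    # Hypothetical loader construction; batch size chosen for illustration only
    train_dataset = FestivalDataset(train_encodings, train_labels)
    val_dataset = FestivalDataset(val_encodings, val_labels)
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16)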
@@ -66,20 +70,22 @@ from transformers import BertForSequenceClassification, AdamW
 import torch.nn as nn
 
 # Load the pretrained model
+print("Loading pretrained model")
 model = BertForSequenceClassification.from_pretrained(
-    'bert-base-chinese',
-    num_labels=1,  # single logit output for binary classification
-    problem_type="single_label_classification"
+    "bert-base-chinese",
+    num_labels=2,  # two logits for binary classification (pairs with CrossEntropyLoss)
+    problem_type="single_label_classification",
 )
 
 # Define training parameters
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 optimizer = AdamW(model.parameters(), lr=2e-5)
-loss_fn = nn.BCEWithLogitsLoss()  # for a single-logit binary head
+loss_fn = torch.nn.CrossEntropyLoss()  # for the two-logit binary head
 
 # Training loop
 for epoch in range(5):
+    print("开始第{}轮训练".format(epoch + 1))
     model.train()
     total_loss = 0
 
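Unrelated to this change: the AdamW imported from transformers in the hunk context above is deprecated in recent transformers releases (it emits a deprecation warning and may be removed in future versions). torch.optim.AdamW is the usual drop-in replacement; a sketch assuming no other optimizer settings are needed:

    from torch.optim import AdamW  # replaces transformers.AdamW

    optimizer = AdamW(model.parameters(), lr=2e-5)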
@@ -88,7 +94,7 @@ for epoch in range(5):
         inputs = {k: v.to(device) for k, v in batch.items()}
         outputs = model(**inputs)
         logits = outputs.logits.squeeze()
-        loss = loss_fn(logits, inputs['labels'].float())
+        loss = loss_fn(logits, inputs["labels"])
         loss.backward()
         optimizer.step()
         total_loss += loss.item()
@@ -103,7 +109,7 @@ for epoch in range(5):
             inputs = {k: v.to(device) for k, v in batch.items()}
             outputs = model(**inputs)
             logits = outputs.logits.squeeze()
-            loss = loss_fn(logits, inputs['labels'].float())
+            loss = loss_fn(logits, inputs["labels"])  # CrossEntropyLoss expects integer class indices
             val_loss += loss.item()
 
     print(f"Epoch {epoch + 1} | Val Loss: {val_loss / len(val_loader):.4f}")