train_model.py

  1. """
  2. @author: luojunhui
  3. """
  4. import pandas as pd
  5. from sklearn.model_selection import train_test_split
  6. from transformers import BertTokenizer
  7. from transformers import BertForSequenceClassification, AdamW
  8. # 创建PyTorch数据集
  9. import torch
  10. from torch.utils.data import Dataset, DataLoader
  11. # 加载数据
  12. df = pd.read_csv("festival_data.csv")
  13. texts = df["text"].tolist()
  14. labels = df["label"].tolist()
  15. # 划分数据集
  16. print("开始划分数据集")
  17. train_texts, val_texts, train_labels, val_labels = train_test_split(
  18. texts, labels, test_size=0.2, stratify=labels
  19. )
  20. # 初始化分词器
  21. tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
  22. # 文本编码函数
  23. def encode_texts(text_list, max_len=32):
  24. return tokenizer(
  25. text_list,
  26. padding="max_length",
  27. truncation=True,
  28. max_length=max_len,
  29. return_tensors="pt",
  30. )
  31. # 编码训练集/验证集
  32. print("开始编码训练集/验证集")
  33. train_encodings = encode_texts(train_texts)
  34. val_encodings = encode_texts(val_texts)
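# Optional sanity check (a minimal sketch): with padding="max_length" and
# max_length=32, every split encodes to fixed-width tensors of shape
# (num_texts, 32).
print("train input_ids shape:", tuple(train_encodings["input_ids"].shape))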

# PyTorch dataset wrapping the encoded tensors and labels
class FestivalDataset(Dataset):
    def __init__(self, encodings, labels):
        self.input_ids = encodings["input_ids"]
        self.attention_mask = encodings["attention_mask"]
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx],
        }

    def __len__(self):
        return len(self.labels)

train_dataset = FestivalDataset(train_encodings, train_labels)
val_dataset = FestivalDataset(val_encodings, val_labels)

# Create the data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
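# Optional shape check (a quick sketch): a full training batch should yield
# (32, 32) tensors for input_ids/attention_mask and a (32,) labels tensor
# (the last batch of an epoch may be smaller).
first_batch = next(iter(train_loader))
print({k: tuple(v.shape) for k, v in first_batch.items()})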

# Load the pretrained model
print("Loading the pretrained model")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=2,  # binary classification as two logits, one per class
    problem_type="single_label_classification",
)

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = torch.nn.CrossEntropyLoss()  # suitable for two-logit binary classification

# Training loop
for epoch in range(5):
    print("Starting epoch {}".format(epoch + 1))
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        loss = loss_fn(outputs.logits, inputs["labels"])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1} | Train Loss: {total_loss / len(train_loader):.4f}")

    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = loss_fn(outputs.logits, inputs["labels"])
            val_loss += loss.item()
    print(f"Epoch {epoch + 1} | Val Loss: {val_loss / len(val_loader):.4f}")

# Save the full model
torch.save(model, "festival_bert_model.pth")
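# Note: torch.save(model, ...) pickles the entire object, so reloading requires
# the same class definitions on the import path. A more portable alternative
# (sketch, using the standard Hugging Face serialization, reloadable with
# from_pretrained) would be:
# model.save_pretrained("festival_bert_model")
# tokenizer.save_pretrained("festival_bert_model")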

from sklearn.metrics import roc_auc_score, roc_curve

def calculate_auc(model, dataloader):
    model.eval()
    true_labels = []
    pred_probs = []
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            # Softmax over the two logits, then take the positive-class column
            probabilities = torch.softmax(outputs.logits, dim=1)
            positive_class_probs = probabilities[:, 1].cpu().numpy()
            pred_probs.extend(positive_class_probs)
            true_labels.extend(inputs["labels"].cpu().numpy())
    auc_score = roc_auc_score(true_labels, pred_probs)
    fpr, tpr, _ = roc_curve(true_labels, pred_probs)
    return auc_score, fpr, tpr

print("Evaluating the model")
# Compute validation-set AUC
auc_val, fpr, tpr = calculate_auc(model, val_loader)
print(f"Validation AUC: {auc_val:.4f}")
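# The fpr/tpr arrays returned above can be drawn as an ROC curve. A minimal
# sketch, assuming matplotlib is installed (it is not used elsewhere here):
# import matplotlib.pyplot as plt
# plt.plot(fpr, tpr, label=f"AUC = {auc_val:.4f}")
# plt.plot([0, 1], [0, 1], linestyle="--")  # chance-level diagonal
# plt.xlabel("False positive rate")
# plt.ylabel("True positive rate")
# plt.legend()
# plt.savefig("roc_curve.png")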

# Check for data leakage (overlap between train and validation texts)
assert len(set(train_texts) & set(val_texts)) == 0, "Data leakage detected!"
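# Usage sketch for the saved model (hypothetical follow-up, not part of
# training): reload the pickled model and score one new text. Assumes the
# definitions above (device, encode_texts) are available where this runs.
# loaded = torch.load("festival_bert_model.pth", map_location=device)
# loaded.eval()
# enc = encode_texts(["元宵节快乐"])
# with torch.no_grad():
#     logits = loaded(input_ids=enc["input_ids"].to(device),
#                     attention_mask=enc["attention_mask"].to(device)).logits
# print(torch.softmax(logits, dim=1)[:, 1])  # positive-class probability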