train_model.py

  1. """
  2. @author: luojunhui
  3. """
  4. import pandas as pd
  5. from sklearn.model_selection import train_test_split
  6. from transformers import BertTokenizer
  7. # 加载数据
  8. df = pd.read_csv("festival_data.csv")
  9. texts = df['text'].tolist()
  10. labels = df['label'].tolist()
  11. # 划分数据集
  12. train_texts, val_texts, train_labels, val_labels = train_test_split(
  13. texts, labels, test_size=0.2, stratify=labels) # 保证类别平衡[1,4](@ref)
# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

# Text-encoding helper
def encode_texts(text_list, max_len=32):
    return tokenizer(
        text_list,
        padding='max_length',
        truncation=True,
        max_length=max_len,
        return_tensors='pt'
    )
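
# Note: the tokenizer call above returns a dict-like BatchEncoding holding
# 'input_ids', 'attention_mask', and 'token_type_ids' tensors; only the first
# two are consumed by FestivalDataset below. max_len=32 assumes the festival
# texts are short; anything longer is truncated.
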
# Encode the training/validation sets
train_encodings = encode_texts(train_texts)
val_encodings = encode_texts(val_texts)

# Build a PyTorch dataset
import torch
from torch.utils.data import Dataset, DataLoader

class FestivalDataset(Dataset):
    def __init__(self, encodings, labels):
        self.input_ids = encodings['input_ids']
        self.attention_mask = encodings['attention_mask']
        # Store labels as floats: BCEWithLogitsLoss (used below) expects float targets
        self.labels = torch.tensor(labels, dtype=torch.float)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
            'labels': self.labels[idx]
        }

    def __len__(self):
        return len(self.labels)

train_dataset = FestivalDataset(train_encodings, train_labels)
val_dataset = FestivalDataset(val_encodings, val_labels)

# Build the data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
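
# Optional sanity check (not in the original script): a training batch should
# hold (batch_size, max_len) tensors for input_ids/attention_mask and a
# (batch_size,) tensor of labels.
first_batch = next(iter(train_loader))
print({k: tuple(v.shape) for k, v in first_batch.items()})
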
from transformers import BertForSequenceClassification
from torch.optim import AdamW  # transformers' own AdamW is deprecated
import torch.nn as nn

# Load the pretrained model: num_labels=1 gives a single output logit for
# binary classification, scored below with BCEWithLogitsLoss
model = BertForSequenceClassification.from_pretrained(
    'bert-base-chinese',
    num_labels=1
)

# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.BCEWithLogitsLoss()  # suited to binary classification
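
# BCEWithLogitsLoss fuses the sigmoid with binary cross-entropy, so the raw
# logit is passed in unchanged and targets must be float 0/1 -- which is why
# FestivalDataset stores its labels as floats.
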
# Training loop
for epoch in range(5):
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        # Pull labels out of the batch so the model's forward pass does not
        # try to compute its own (mismatched) loss internally
        labels_batch = batch.pop('labels').to(device)
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        logits = outputs.logits.squeeze(-1)  # squeeze(-1) preserves the batch dim
        loss = loss_fn(logits, labels_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1} | Train Loss: {total_loss / len(train_loader):.4f}")

    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            labels_batch = batch.pop('labels').to(device)
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            logits = outputs.logits.squeeze(-1)
            loss = loss_fn(logits, labels_batch)
            val_loss += loss.item()
    print(f"Epoch {epoch + 1} | Val Loss: {val_loss / len(val_loader):.4f}")

# Save the full model object (saving model.state_dict() is the more portable alternative)
torch.save(model, "festival_bert_model.pth")
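
# ---------------------------------------------------------------------------
# Minimal inference sketch -- an illustration, not part of the original
# script. It reloads the pickled model and scores one example text; the
# sample string and the 0.5 threshold are assumptions. On newer PyTorch
# versions, weights_only=False is required to un-pickle a full model object.
# ---------------------------------------------------------------------------
model = torch.load("festival_bert_model.pth", map_location=device, weights_only=False)
model.eval()
sample = tokenizer(["祝您新春快乐!"], padding='max_length', truncation=True,
                   max_length=32, return_tensors='pt').to(device)
with torch.no_grad():
    logit = model(input_ids=sample['input_ids'],
                  attention_mask=sample['attention_mask']).logits.squeeze(-1)
prob = torch.sigmoid(logit).item()  # probability of the positive class
print(f"P(festival) = {prob:.3f} -> predicted label {int(prob >= 0.5)}")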