train_model.py

  1. """
  2. @author: luojunhui
  3. """
  4. import pandas as pd
  5. from sklearn.model_selection import train_test_split
  6. from transformers import BertTokenizer
  7. # 加载数据
  8. df = pd.read_csv("festival_data.csv")
  9. texts = df["text"].tolist()
  10. labels = df["label"].tolist()
  11. # 划分数据集
  12. print("开始划分数据集")
  13. train_texts, val_texts, train_labels, val_labels = train_test_split(
  14. texts, labels, test_size=0.2, stratify=labels
  15. ) # 保证类别平衡[1,4](@ref)
# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

# Text-encoding helper
def encode_texts(text_list, max_len=32):
    return tokenizer(
        text_list,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )

# Encode the training/validation sets
print("Encoding the training/validation sets")
train_encodings = encode_texts(train_texts)
val_encodings = encode_texts(val_texts)
# Build PyTorch datasets
import torch
from torch.utils.data import Dataset, DataLoader

class FestivalDataset(Dataset):
    def __init__(self, encodings, labels):
        self.input_ids = encodings["input_ids"]
        self.attention_mask = encodings["attention_mask"]
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx],
        }

    def __len__(self):
        return len(self.labels)

train_dataset = FestivalDataset(train_encodings, train_labels)
val_dataset = FestivalDataset(val_encodings, val_labels)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
from transformers import BertForSequenceClassification
from torch.optim import AdamW  # transformers.AdamW is deprecated in recent versions

# Load the pretrained model
print("Loading the pretrained model")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=2,  # two-logit classification head for the binary task
    problem_type="single_label_classification",
)

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = torch.nn.CrossEntropyLoss()  # cross-entropy over the two classes
# Training loop
for epoch in range(5):
    print("Starting epoch {}".format(epoch + 1))
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        logits = outputs.logits  # shape (batch_size, 2); squeezing would break batches of size 1
        loss = loss_fn(logits, inputs["labels"])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1} | Train Loss: {total_loss / len(train_loader):.4f}")

    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            logits = outputs.logits
            loss = loss_fn(logits, inputs["labels"])  # CrossEntropyLoss expects integer class indices
            val_loss += loss.item()
    print(f"Epoch {epoch + 1} | Val Loss: {val_loss / len(val_loader):.4f}")
# Save the full model (pickles the whole module; loading it requires transformers to be installed)
torch.save(model, "festival_bert_model.pth")
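
# --- Quick inference check: a minimal sketch, not part of the original script ---
# Assumes the file saved above and the tokenizer defined earlier; the sample sentence
# is illustrative, and what labels 0/1 mean depends on how festival_data.csv was encoded.
loaded = torch.load(
    "festival_bert_model.pth", map_location=device, weights_only=False
)  # weights_only=False because the whole module was pickled
loaded.eval()
sample = tokenizer(
    "元宵节快乐", padding="max_length", truncation=True, max_length=32, return_tensors="pt"
).to(device)
with torch.no_grad():
    pred = loaded(**sample).logits.argmax(dim=-1).item()
print("Predicted label:", pred)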