- """
- @author: luojunhui
- """
- import pandas as pd
- from sklearn.model_selection import train_test_split
- from transformers import BertTokenizer
- from transformers import BertForSequenceClassification, AdamW
- # 创建PyTorch数据集
- import torch
- from torch.utils.data import Dataset, DataLoader
- # 加载数据
- df = pd.read_csv("festival_data.csv")
- texts = df["text"].tolist()
- labels = df["label"].tolist()
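# Assumed input schema (not specified in the original): festival_data.csv is
# expected to contain a "text" column of short Chinese sentences and a "label"
# column of integer class ids 0/1, matching num_labels=2 below.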

# Split into train/validation sets
print("Splitting the dataset")
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.2, stratify=labels, random_state=42  # fixed seed for reproducibility
)
# Fail fast on data leakage (train/validation text overlap) before training starts
assert len(set(train_texts) & set(val_texts)) == 0, "Data leakage detected!"

# Initialize the tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

# Text-encoding helper: pad/truncate to a fixed length and return PyTorch tensors
def encode_texts(text_list, max_len=32):
    return tokenizer(
        text_list,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )

# Encode the train/validation sets
print("Encoding the train/validation sets")
train_encodings = encode_texts(train_texts)
val_encodings = encode_texts(val_texts)

# PyTorch dataset wrapping the encoded inputs and labels
class FestivalDataset(Dataset):
    def __init__(self, encodings, labels):
        self.input_ids = encodings["input_ids"]
        self.attention_mask = encodings["attention_mask"]
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "attention_mask": self.attention_mask[idx],
            "labels": self.labels[idx],
        }

    def __len__(self):
        return len(self.labels)

train_dataset = FestivalDataset(train_encodings, train_labels)
val_dataset = FestivalDataset(val_encodings, val_labels)

# Create the data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)

# Load the pretrained model
print("Loading the pretrained model")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese",
    num_labels=2,  # two output logits, one per class
    problem_type="single_label_classification",
)

# Training setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = torch.nn.CrossEntropyLoss()  # standard loss for single-label, two-class output
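# Note: because each batch dict includes a "labels" key, model(**inputs) also
# returns a ready-made outputs.loss; the explicit CrossEntropyLoss used here is
# an equivalent alternative for single_label_classification.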

# Training loop
for epoch in range(5):
    print("Starting epoch {}".format(epoch + 1))
    model.train()
    total_loss = 0
    for batch in train_loader:
        optimizer.zero_grad()
        inputs = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**inputs)
        # logits shape: (batch_size, 2); no squeeze — squeezing would break on a final batch of size 1
        loss = loss_fn(outputs.logits, inputs["labels"])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1} | Train Loss: {total_loss / len(train_loader):.4f}")

    # Validation phase
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            val_loss += loss_fn(outputs.logits, inputs["labels"]).item()
    print(f"Epoch {epoch + 1} | Val Loss: {val_loss / len(val_loader):.4f}")

# Save the trained weights (a state_dict is more portable than pickling the full model)
torch.save(model.state_dict(), "festival_bert_model.pth")

def calculate_auc(model, dataloader):
    """Collect positive-class probabilities and compute ROC-AUC over a dataloader."""
    model.eval()
    true_labels = []
    pred_probs = []
    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            # Softmax over the two logits, then keep column 1 (the positive class)
            probabilities = torch.softmax(outputs.logits, dim=1)
            pred_probs.extend(probabilities[:, 1].cpu().numpy())
            true_labels.extend(inputs["labels"].cpu().numpy())
    auc_score = roc_auc_score(true_labels, pred_probs)
    fpr, tpr, _ = roc_curve(true_labels, pred_probs)
    return auc_score, fpr, tpr
- print("评估模型")
- # 计算验证集AUC
- auc_val, fpr, tpr = calculate_auc(model, val_loader)
- print(f"Validation AUC: {auc_val:.4f}")
- # 检查数据是否泄露(训练集和测试集重叠)
- assert len(set(train_texts) & set(val_texts)) == 0, "存在数据泄露!"
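
# --- Inference sketch ---
# A minimal example of reloading the saved weights and scoring one new sentence.
# The file name matches the save above; the sample text and the `inference_model`
# name are hypothetical placeholders, not part of the original script.
inference_model = BertForSequenceClassification.from_pretrained(
    "bert-base-chinese", num_labels=2
)
inference_model.load_state_dict(
    torch.load("festival_bert_model.pth", map_location=device)
)
inference_model.to(device)
inference_model.eval()

sample = encode_texts(["祝你节日快乐!"])  # hypothetical sample text
with torch.no_grad():
    sample_logits = inference_model(
        input_ids=sample["input_ids"].to(device),
        attention_mask=sample["attention_mask"].to(device),
    ).logits
print(f"P(label=1) = {torch.softmax(sample_logits, dim=1)[0, 1].item():.4f}")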