|
@@ -116,3 +116,41 @@ for epoch in range(5):
|
|
|
|
|
|
# 保存完整模型
|
|
|
torch.save(model, "festival_bert_model.pth")
|
|
|
+
|
|
|
+from sklearn.metrics import roc_auc_score, roc_curve
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+
|
|
|
+def calculate_auc(model, dataloader):
|
|
|
+ model.eval()
|
|
|
+ true_labels = []
|
|
|
+ pred_probs = []
|
|
|
+
|
|
|
+ with torch.no_grad():
|
|
|
+ for batch in dataloader:
|
|
|
+ inputs = {k: v.to(device) for k, v in batch.items()}
|
|
|
+ outputs = model(**inputs)
|
|
|
+ probs = torch.sigmoid(outputs.logits.squeeze()).cpu().numpy()
|
|
|
+ pred_probs.extend(probs)
|
|
|
+ true_labels.extend(inputs['labels'].cpu().numpy())
|
|
|
+
|
|
|
+ auc_score = roc_auc_score(true_labels, pred_probs)
|
|
|
+ fpr, tpr, _ = roc_curve(true_labels, pred_probs)
|
|
|
+ return auc_score, fpr, tpr
|
|
|
+
|
|
|
+
|
|
|
+print("评估模型")
|
|
|
+# 计算验证集AUC
|
|
|
+auc_val, fpr, tpr = calculate_auc(model, val_loader)
|
|
|
+print(f"Validation AUC: {auc_val:.4f}")
|
|
|
+
|
|
|
+# 可视化ROC曲线
|
|
|
+import matplotlib.pyplot as plt
|
|
|
+
|
|
|
+plt.plot(fpr, tpr, label=f'AUC = {auc_val:.2f}')
|
|
|
+plt.plot([0, 1], [0, 1], linestyle='--')
|
|
|
+plt.xlabel('False Positive Rate')
|
|
|
+plt.ylabel('True Positive Rate')
|
|
|
+plt.title('ROC Curve')
|
|
|
+plt.legend()
|
|
|
+plt.show()
|