luojunhui 4 months ago
parent
commit
8c7dfddcee
1 changed files with 38 additions and 0 deletions
  1. 38 0
      train_model.py

+ 38 - 0
train_model.py

@@ -116,3 +116,41 @@ for epoch in range(5):
 
 # 保存完整模型
 torch.save(model, "festival_bert_model.pth")
+
+from sklearn.metrics import roc_auc_score, roc_curve
+import numpy as np
+
+
+def calculate_auc(model, dataloader):
+    model.eval()
+    true_labels = []
+    pred_probs = []
+
+    with torch.no_grad():
+        for batch in dataloader:
+            inputs = {k: v.to(device) for k, v in batch.items()}
+            outputs = model(**inputs)
+            probs = torch.sigmoid(outputs.logits.squeeze()).cpu().numpy()
+            pred_probs.extend(probs)
+            true_labels.extend(inputs['labels'].cpu().numpy())
+
+    auc_score = roc_auc_score(true_labels, pred_probs)
+    fpr, tpr, _ = roc_curve(true_labels, pred_probs)
+    return auc_score, fpr, tpr
+
+
+print("评估模型")
+# 计算验证集AUC
+auc_val, fpr, tpr = calculate_auc(model, val_loader)
+print(f"Validation AUC: {auc_val:.4f}")
+
+# 可视化ROC曲线
+import matplotlib.pyplot as plt
+
+plt.plot(fpr, tpr, label=f'AUC = {auc_val:.2f}')
+plt.plot([0, 1], [0, 1], linestyle='--')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC Curve')
+plt.legend()
+plt.show()