luojunhui 4 meses atrás
pai
commit
2fc4d94886
1 arquivos alterados com 2 adições e 10 exclusões
  1. 2 10
      train_model.py

+ 2 - 10
train_model.py

@@ -148,13 +148,5 @@ print("评估模型")
 auc_val, fpr, tpr = calculate_auc(model, val_loader)
 print(f"Validation AUC: {auc_val:.4f}")
 
-# 可视化ROC曲线
-import matplotlib.pyplot as plt
-
-plt.plot(fpr, tpr, label=f'AUC = {auc_val:.2f}')
-plt.plot([0, 1], [0, 1], linestyle='--')
-plt.xlabel('False Positive Rate')
-plt.ylabel('True Positive Rate')
-plt.title('ROC Curve')
-plt.legend()
-plt.show()
+# 检查数据是否泄露(训练集和测试集重叠)
+assert len(set(train_texts) & set(val_texts)) == 0, "存在数据泄露!"