|
@@ -148,13 +148,5 @@ print("评估模型")
|
|
|
auc_val, fpr, tpr = calculate_auc(model, val_loader)
|
|
|
print(f"Validation AUC: {auc_val:.4f}")
|
|
|
|
|
|
-# 可视化ROC曲线
|
|
|
-import matplotlib.pyplot as plt
|
|
|
-
|
|
|
-plt.plot(fpr, tpr, label=f'AUC = {auc_val:.2f}')
|
|
|
-plt.plot([0, 1], [0, 1], linestyle='--')
|
|
|
-plt.xlabel('False Positive Rate')
|
|
|
-plt.ylabel('True Positive Rate')
|
|
|
-plt.title('ROC Curve')
|
|
|
-plt.legend()
|
|
|
-plt.show()
|
|
|
+# 检查数据是否泄露(训练集和测试集重叠)
|
|
|
+assert len(set(train_texts) & set(val_texts)) == 0, "存在数据泄露!"
|