liqian 1 年之前
父节点
当前提交
8797dabd3a
共有 1 个文件被更改,包括 8 次插入2 次删除
  1. 8 2
      ad_xgboost_train.py

+ 8 - 2
ad_xgboost_train.py

@@ -24,13 +24,19 @@ xgb_model = XGBClassifier(
     objective='binary:logistic',
     learning_rate=0.3,
     max_depth=10,
-    eval_metric='auc'
+    eval_metric=['error', 'logloss']
 )
-xgb_model.fit(x_train, y_train)
+xgb_model.fit(x_train, y_train, eval_set=[(x_train, y_train), (x_test, y_test)])
 # 5. 模型保存
 xgb_model.save_model('./data/ad_xgb.model')
 # 6. 测试集预测
 y_test_pre = xgb_model.predict(x_test)
+
+test_df = x_test.copy()
+test_df['y'] = y_test
+test_df['y_pre'] = y_test_pre
+test_df.to_csv('./data/test_pre.csv', index=False)
+
 # 7. 模型效果验证
 test_accuracy = metrics.accuracy_score(y_test, y_test_pre)
 print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))