liqian 1 year ago
parent
commit
f8dc773593
1 changed files with 4 additions and 2 deletions
  1. 4 2
      ad_xgboost_threshold_update.py

+ 4 - 2
ad_xgboost_threshold_update.py

@@ -90,18 +90,20 @@ def threshold_update(project, table, dt, app_type):
     print(f"feature_initial_df shape: {feature_initial_df.shape}")
     # 获取所需的字段
     predict_df = feature_initial_df[features[4:]]
+    print(f"predict_df shape: {predict_df.shape}")
 
     # 2. 不出广告情况下的预测
     predict_df_0 = predict_df.copy()
     predict_df_0['ad_status'] = 0
     y_pred_proba_0 = model.predict_proba(predict_df_0)
-    predict_df['y_0'] = [x[1] for x in y_pred_proba_0]
-    print(f"predict_df shape: {predict_df.shape}")
 
     # 3. 出广告情况下的预测
     predict_df_1 = predict_df.copy()
     predict_df_1['ad_status'] = 1
     y_pred_proba_1 = model.predict_proba(predict_df_1)
+
+    predict_df['y_0'] = [x[1] for x in y_pred_proba_0]
+    print(f"predict_df shape: {predict_df.shape}")
     predict_df['y_1'] = [x[1] for x in y_pred_proba_1]
     print(f"predict_df shape: {predict_df.shape}")