ad_xgboost_predict.py 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import numpy as np
  2. import xgboost as xgb
  3. from xgboost.sklearn import XGBClassifier
  4. from utils import RedisHelper
  5. from config import set_config
# Module-level singletons, created once at import time and shared by every
# call to xgboost_predict.
redis_helper = RedisHelper()  # used for all feature / threshold lookups in Redis
config_ = set_config()  # project configuration; AD_MODEL_CONFIG['xgb'] is read from it

# Model loading (disabled here; xgboost_predict receives the model as an
# argument instead). Kept as a reference for how a model can be built:
# model = XGBClassifier()
# booster = xgb.Booster()
# booster.load_model('./data/ad_xgb.model')
# model._Booster = booster
  13. def xgboost_predict(model, app_type, mid, video_id, abtest_id, ab_test_code):
  14. xgb_config = config_.AD_MODEL_CONFIG['xgb']
  15. # 1. 获取user特征
  16. user_feature_key = f"{xgb_config['predict_user_feature_key_prefix']}{app_type}:{mid}"
  17. user_feature = redis_helper.get_data_from_redis(key_name=user_feature_key)
  18. if user_feature is None:
  19. user_feature_key = f"{xgb_config['predict_user_feature_key_prefix']}{app_type}:-1"
  20. user_feature = redis_helper.get_data_from_redis(key_name=user_feature_key)
  21. user_feature = eval(user_feature)
  22. # 2. 获取video特征
  23. video_feature_key = f"{xgb_config['predict_video_feature_key_prefix']}{app_type}:{video_id}"
  24. video_feature = redis_helper.get_data_from_redis(key_name=video_feature_key)
  25. if video_feature is None:
  26. video_feature_key = f"{xgb_config['predict_video_feature_key_prefix']}{app_type}:-1"
  27. video_feature = redis_helper.get_data_from_redis(key_name=video_feature_key)
  28. video_feature = eval(video_feature)
  29. # 3. 拼接不出广告时的特征 & 预测
  30. ad_feature_0 = user_feature + video_feature + [0]
  31. # ad_0_predict = model.predict_proba(np.array([ad_feature_0]))
  32. # ad_0_predict = ad_0_predict[0][1]
  33. # ad_0_predict = 0.7
  34. # 4. 拼接出广告时的特征 & 预测
  35. ad_feature_1 = user_feature + video_feature + [1]
  36. # ad_1_predict = model.predict_proba(np.array([ad_feature_1]))
  37. # ad_1_predict = ad_1_predict[0][1]
  38. # ad_1_predict = 0.6
  39. # 预测
  40. predict = model.predict_proba(np.array([ad_feature_0, ad_feature_1]))
  41. ad_0_predict, ad_1_predict = predict[0][1], predict[1][1]
  42. # 5. 作差
  43. predict_res = ad_0_predict - ad_1_predict
  44. # 6. 获取阈值
  45. threshold_key_name = f"{xgb_config['threshold_key_prefix']}{abtest_id}:{ab_test_code}"
  46. threshold = redis_helper.get_data_from_redis(key_name=threshold_key_name)
  47. if threshold is None:
  48. threshold = 0
  49. else:
  50. threshold = float(threshold)
  51. # 7. 阈值判断
  52. if predict_res > threshold:
  53. # 大于阈值,不出广告
  54. ad_predict = 1
  55. else:
  56. # 否则,出广告
  57. ad_predict = 2
  58. result = {
  59. 'predict_tag': 'xgboost',
  60. 'ad_0_predict': ad_0_predict,
  61. 'ad_1_predict': ad_1_predict,
  62. 'predict_res': predict_res,
  63. 'threshold': threshold,
  64. 'ad_predict': ad_predict}
  65. return result