# ad_xgboost_predict.py (2.5 KB)

import numpy as np
import xgboost as xgb
from xgboost.sklearn import XGBClassifier

from utils import RedisHelper
from config import set_config
  6. redis_helper = RedisHelper()
  7. config_ = set_config()
  8. # 模型加载
  9. model = XGBClassifier()
  10. booster = xgb.Booster()
  11. booster.load_model('./data/ad_xgb.model')
  12. model._Booster = booster
  13. def xgboost_predict(app_type, mid, video_id, abtest_id, ab_test_code):
  14. xgb_config = config_.AD_MODEL_CONFIG['xgb']
  15. # 1. 获取user特征
  16. user_feature_key = f"{xgb_config['predict_user_feature_key_prefix']}{app_type}:{mid}"
  17. user_feature = redis_helper.get_data_from_redis(key_name=user_feature_key)
  18. if user_feature is None:
  19. user_feature_key = f"{xgb_config['predict_user_feature_key_prefix']}{app_type}:-1"
  20. user_feature = redis_helper.get_data_from_redis(key_name=user_feature_key)
  21. user_feature = eval(user_feature)
  22. # 2. 获取video特征
  23. video_feature_key = f"{xgb_config['predict_video_feature_key_prefix']}{app_type}:{video_id}"
  24. video_feature = redis_helper.get_data_from_redis(key_name=video_feature_key)
  25. if video_feature is None:
  26. video_feature_key = f"{xgb_config['predict_video_feature_key_prefix']}{app_type}:-1"
  27. video_feature = redis_helper.get_data_from_redis(key_name=video_feature_key)
  28. video_feature = eval(video_feature)
  29. # 3. 拼接出广告时的特征 & 预测
  30. ad_feature_0 = user_feature + video_feature + [0]
  31. ad_0_predict = model.predict_proba(np.array([ad_feature_0]))
  32. ad_0_predict = ad_0_predict[0][1]
  33. # 4. 拼接不出广告时的特征 & 预测
  34. ad_feature_1 = user_feature + video_feature + [1]
  35. ad_1_predict = model.predict_proba(np.array([ad_feature_1]))
  36. ad_1_predict = ad_1_predict[0][1]
  37. # 5. 作差
  38. predict_res = ad_0_predict - ad_1_predict
  39. # 6. 获取阈值
  40. threshold_key_name = f"{xgb_config['threshold_key_prefix']}{abtest_id}:{ab_test_code}"
  41. threshold = redis_helper.get_data_from_redis(key_name=threshold_key_name)
  42. if threshold is None:
  43. threshold = 0
  44. else:
  45. threshold = float(threshold)
  46. # 7. 阈值判断
  47. if predict_res > threshold:
  48. # 大于阈值,不出广告
  49. ad_predict = 1
  50. else:
  51. # 否则,出广告
  52. ad_predict = 2
  53. result = {
  54. 'predict_tag': 'xgboost',
  55. 'ad_0_predict': ad_0_predict,
  56. 'ad_1_predict': ad_1_predict,
  57. 'predict_res': predict_res,
  58. 'threshold': threshold,
  59. 'ad_predict': ad_predict}
  60. return result