ad_xgboost_predict.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import numpy as np
  2. import xgboost as xgb
  3. from xgboost.sklearn import XGBClassifier
  4. from utils import RedisHelper
  5. from config import set_config
  # Module-level Redis client and config, created once at import time and
  # used by xgboost_predict below.
  redis_helper = RedisHelper()
  config_ = set_config()
  # Model loading — currently disabled (commented out). xgboost_predict takes
  # the model through its `model` parameter instead.
  # model = XGBClassifier()
  # booster = xgb.Booster()
  # booster.load_model('./data/ad_xgb.model')
  # model._Booster = booster
  def _get_feature_with_fallback(key_prefix, app_type, item_id):
      """Fetch and parse a feature value from Redis, with a default-key fallback.

      Reads ``{key_prefix}{app_type}:{item_id}``. When that key returns None,
      retries with the app-level default key ``{key_prefix}{app_type}:-1``.

      Args:
          key_prefix: Redis key prefix (taken from ``AD_MODEL_CONFIG['xgb']``).
          app_type: App type segment of the key.
          item_id: Entity segment of the key (``mid`` or ``video_id``).

      Returns:
          The stored value evaluated with ``eval``. The caller concatenates it
          with other lists, so it is presumably a list literal — verify
          against the code that writes these keys.

      Raises:
          TypeError: If neither key holds a value. TypeError is kept because
              that is what the previous ``eval(None)`` failure raised.
      """
      key_name = f"{key_prefix}{app_type}:{item_id}"
      raw_value = redis_helper.get_data_from_redis(key_name=key_name)
      if raw_value is None:
          default_key_name = f"{key_prefix}{app_type}:-1"
          raw_value = redis_helper.get_data_from_redis(key_name=default_key_name)
          if raw_value is None:
              raise TypeError(
                  f"no feature in redis for key {key_name!r} "
                  f"or default key {default_key_name!r}"
              )
      # NOTE(review): eval() executes whatever string is stored in Redis. If the
      # stored values are plain Python literals (e.g. "[0.1, 3, 0]"),
      # ast.literal_eval is a safer drop-in; confirm the writer's format first.
      return eval(raw_value)


  def xgboost_predict(model, app_type, mid, video_id, abtest_id, ab_test_code):
      """Decide whether to show an ad for a (user, video) request.

      Builds one model input for the "no ad" case (trailing flag 0) and one for
      the "ad" case (trailing flag 1). It then takes the difference of their
      scores and compares it with a threshold stored in Redis per A/B test.

      NOTE: model scoring is currently disabled. ``model`` is not used, and
      fixed placeholder scores are used (0.7 without ad, 0.6 with ad).

      Args:
          model: Classifier meant to score the features via ``predict_proba``
              (currently unused).
          app_type: App type; part of the user/video feature Redis keys.
          mid: User identifier used in the user feature key.
          video_id: Video identifier used in the video feature key.
          abtest_id: A/B test id; part of the threshold key.
          ab_test_code: A/B test code; part of the threshold key.

      Returns:
          dict with keys:
            - ``predict_tag``: always 'xgboost'.
            - ``ad_0_predict`` / ``ad_1_predict``: the two scores.
            - ``predict_res``: ad_0_predict - ad_1_predict.
            - ``threshold``: 0 when no threshold is stored.
            - ``ad_predict``: 1 = do not show an ad (predict_res > threshold);
              2 = show an ad.

      Raises:
          TypeError: If neither the specific nor the ``-1`` default feature key
              exists in Redis for the user or for the video.
      """
      xgb_config = config_.AD_MODEL_CONFIG['xgb']

      # 1. User features, falling back to the app-level default user "-1".
      user_feature = _get_feature_with_fallback(
          xgb_config['predict_user_feature_key_prefix'], app_type, mid)
      # 2. Video features, falling back to the app-level default video "-1".
      video_feature = _get_feature_with_fallback(
          xgb_config['predict_video_feature_key_prefix'], app_type, video_id)

      # 3./4. Model inputs without (flag 0) and with (flag 1) an ad.
      ad_feature_0 = user_feature + video_feature + [0]
      ad_feature_1 = user_feature + video_feature + [1]
      # TODO: scoring is disabled, so the feature vectors above are unused and
      # fixed placeholder scores are returned. To re-enable:
      #   ad_0_predict = model.predict_proba(np.array([ad_feature_0]))[0][1]
      #   ad_1_predict = model.predict_proba(np.array([ad_feature_1]))[0][1]
      ad_0_predict = 0.7
      ad_1_predict = 0.6

      # 5. Difference between the no-ad score and the with-ad score.
      predict_res = ad_0_predict - ad_1_predict

      # 6. Threshold for this A/B test; 0 when none is configured.
      threshold_key_name = f"{xgb_config['threshold_key_prefix']}{abtest_id}:{ab_test_code}"
      threshold = redis_helper.get_data_from_redis(key_name=threshold_key_name)
      threshold = 0 if threshold is None else float(threshold)

      # 7. Above the threshold -> no ad (1); otherwise -> show an ad (2).
      ad_predict = 1 if predict_res > threshold else 2

      return {
          'predict_tag': 'xgboost',
          'ad_0_predict': ad_0_predict,
          'ad_1_predict': ad_1_predict,
          'predict_res': predict_res,
          'threshold': threshold,
          'ad_predict': ad_predict,
      }