1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- import numpy as np
- import xgboost as xgb
- from xgboost.sklearn import XGBClassifier
- from utils import RedisHelper
- from config import set_config
# Module-level objects created once at import time and used by xgboost_predict.
redis_helper = RedisHelper()
config_ = set_config()
# Model loading: a raw xgboost Booster is read from disk and attached to an
# XGBClassifier wrapper so the sklearn-style predict_proba() API can be used.
# NOTE(review): './data/ad_xgb.model' is resolved against the process's current
# working directory, not this file's directory — confirm the service's CWD.
# NOTE(review): assigning the private ``_Booster`` attribute is a workaround;
# newer xgboost releases provide XGBClassifier.load_model() — verify against
# the installed version before changing.
model = XGBClassifier()
booster = xgb.Booster()
booster.load_model('./data/ad_xgb.model')
model._Booster = booster
def _get_feature_list(key_prefix, app_type, entity_id):
    """Fetch a feature vector from Redis, falling back to the shared default entry.

    The key ``{key_prefix}{app_type}:{entity_id}`` is tried first; if it is
    absent, ``{key_prefix}{app_type}:-1`` is used instead.

    Args:
        key_prefix: Redis key prefix taken from the xgb model config.
        app_type: App identifier embedded in the key.
        entity_id: User (mid) or video id embedded in the key.

    Returns:
        The result of evaluating the stored string (presumably a list of
        numeric features — confirm against the code that writes these keys).

    Raises:
        ValueError: if neither the specific key nor the ``-1`` default exists.
    """
    key_name = f"{key_prefix}{app_type}:{entity_id}"
    raw = redis_helper.get_data_from_redis(key_name=key_name)
    if raw is None:
        # No entry for this specific id: fall back to the default (-1) entry.
        key_name = f"{key_prefix}{app_type}:-1"
        raw = redis_helper.get_data_from_redis(key_name=key_name)
    if raw is None:
        # Without this check the failure surfaces as an opaque TypeError
        # from eval(None).
        raise ValueError(f"feature not found in redis: {key_name}")
    # NOTE(review): eval() executes whatever is stored under this key. If the
    # stored value is always a plain list literal, ast.literal_eval would be a
    # safer parser — confirm the writer's format before switching.
    return eval(raw)


def xgboost_predict(app_type, mid, video_id, abtest_id, ab_test_code):
    """Decide whether to show an ad, using the XGBoost model.

    The model is scored on the same user + video features twice, once with a
    trailing ad flag of 0 and once with a flag of 1. The difference of the two
    class-1 probabilities is compared against a per-A/B-test threshold read
    from Redis.

    Args:
        app_type: App identifier used in the Redis feature keys.
        mid: User identifier used in the user feature key.
        video_id: Video identifier used in the video feature key.
        abtest_id: A/B test id used in the threshold key.
        ab_test_code: A/B test group code used in the threshold key.

    Returns:
        A dict with keys ``predict_tag`` ('xgboost'), ``ad_0_predict``,
        ``ad_1_predict``, ``predict_res`` (ad_0_predict - ad_1_predict),
        ``threshold`` (0 when none is configured) and ``ad_predict``
        (1 = do not show an ad, 2 = show an ad).

    Raises:
        ValueError: if a user or video feature vector, including its ``-1``
            default entry, is missing from Redis.
    """
    xgb_config = config_.AD_MODEL_CONFIG['xgb']

    # 1-2. User and video features; each falls back to its -1 default entry.
    user_feature = _get_feature_list(
        xgb_config['predict_user_feature_key_prefix'], app_type, mid)
    video_feature = _get_feature_list(
        xgb_config['predict_video_feature_key_prefix'], app_type, video_id)

    # 3-4. Score both ad-flag variants in one batched call.
    # Row 0 ends with flag 0 (the "ad shown" scenario).
    # Row 1 ends with flag 1 (the "no ad shown" scenario).
    # XGBoost scores each row independently, so this gives the same results
    # as two separate single-row predict_proba calls.
    base_feature = [*user_feature, *video_feature]
    proba = model.predict_proba(
        np.array([base_feature + [0], base_feature + [1]]))
    ad_0_predict = proba[0][1]
    ad_1_predict = proba[1][1]

    # 5. Difference between the two scenarios.
    predict_res = ad_0_predict - ad_1_predict

    # 6. Threshold for this A/B test group; defaults to 0 when not configured.
    threshold_key_name = f"{xgb_config['threshold_key_prefix']}{abtest_id}:{ab_test_code}"
    threshold = redis_helper.get_data_from_redis(key_name=threshold_key_name)
    threshold = 0 if threshold is None else float(threshold)

    # 7. Above the threshold -> do not show an ad (1); otherwise show one (2).
    ad_predict = 1 if predict_res > threshold else 2

    # NOTE(review): the probability values are numpy scalars, because
    # predict_proba returns an ndarray. Convert them with float() if this dict
    # is JSON-serialised downstream.
    return {
        'predict_tag': 'xgboost',
        'ad_0_predict': ad_0_predict,
        'ad_1_predict': ad_1_predict,
        'predict_res': predict_res,
        'threshold': threshold,
        'ad_predict': ad_predict,
    }
|