import numpy as np
import xgboost as xgb
from xgboost.sklearn import XGBClassifier

from utils import RedisHelper
from config import set_config

redis_helper = RedisHelper()
config_ = set_config()

# Model loading, kept for reference; the model is now passed in by the caller:
# model = XGBClassifier()
# booster = xgb.Booster()
# booster.load_model('./data/ad_xgb.model')
# model._Booster = booster


def _get_feature_with_fallback(key_prefix, app_type, entity_id):
    """Fetch a feature list from Redis, falling back to the app-level default.

    Looks up ``f"{key_prefix}{app_type}:{entity_id}"`` first and, if that key
    is absent, ``f"{key_prefix}{app_type}:-1"``.

    :param key_prefix: Redis key prefix taken from the xgb model config.
    :param app_type: application type, part of the key.
    :param entity_id: id of the entity (e.g. mid or video_id) to look up.
    :returns: the evaluated feature value stored in Redis (used by the caller
        as a list that is concatenated with other feature lists).
    :raises ValueError: if neither the specific nor the default key exists.
    """
    tried_keys = []
    for lookup_id in (entity_id, -1):
        key_name = f"{key_prefix}{app_type}:{lookup_id}"
        tried_keys.append(key_name)
        raw_feature = redis_helper.get_data_from_redis(key_name=key_name)
        if raw_feature is not None:
            # SECURITY NOTE(review): eval() executes whatever is stored in
            # Redis. If that data can ever be written by an untrusted party,
            # switch to ast.literal_eval / json. Confirm first that the stored
            # format contains only literals (eval here can also resolve names
            # such as `np` from this module's namespace).
            return eval(raw_feature)
    raise ValueError(f"feature not found in redis, tried keys: {tried_keys}")


def xgboost_predict(model, app_type, mid, video_id, abtest_id, ab_test_code):
    """Decide whether to show an ad using the XGBoost model and a Redis threshold.

    The model is scored twice on the same user+video features, once with a
    trailing ad flag of 0 and once with 1. Per the original comments, flag 0
    is the "ad shown" case and flag 1 is the "no ad" case. The difference of
    the positive-class probabilities is compared against a per-A/B-test
    threshold read from Redis.

    :param model: fitted classifier exposing ``predict_proba`` (an
        ``XGBClassifier`` per the reference loading code above).
    :param app_type: application type, used in all Redis keys.
    :param mid: user id used to look up user features (falls back to -1).
    :param video_id: video id used to look up video features (falls back to -1).
    :param abtest_id: A/B test id, part of the threshold key.
    :param ab_test_code: A/B test code, part of the threshold key.
    :returns: dict with keys ``predict_tag``, ``ad_0_predict``,
        ``ad_1_predict``, ``predict_res``, ``threshold`` and ``ad_predict``
        (1 = do not show ad, 2 = show ad).
    :raises ValueError: if a user or video feature (including its -1
        default) is missing from Redis.
    """
    xgb_config = config_.AD_MODEL_CONFIG['xgb']

    # 1./2. User and video features, each with an app-level (-1) fallback.
    user_feature = _get_feature_with_fallback(
        xgb_config['predict_user_feature_key_prefix'], app_type, mid)
    video_feature = _get_feature_with_fallback(
        xgb_config['predict_video_feature_key_prefix'], app_type, video_id)

    # 3./4. Build the flag-0 ("ad shown") and flag-1 ("no ad") rows and score
    # both in a single call; predictions are computed per row, so this yields
    # the same probabilities as two separate one-row calls.
    ad_feature_0 = user_feature + video_feature + [0]
    ad_feature_1 = user_feature + video_feature + [1]
    probabilities = model.predict_proba(np.array([ad_feature_0, ad_feature_1]))
    ad_0_predict = probabilities[0][1]
    ad_1_predict = probabilities[1][1]

    # 5. Difference between the two scenarios.
    predict_res = ad_0_predict - ad_1_predict

    # 6. Per-A/B-test threshold; defaults to 0 when not configured in Redis.
    threshold_key_name = f"{xgb_config['threshold_key_prefix']}{abtest_id}:{ab_test_code}"
    threshold = redis_helper.get_data_from_redis(key_name=threshold_key_name)
    if threshold is None:
        threshold = 0
    else:
        threshold = float(threshold)

    # 7. Threshold decision.
    if predict_res > threshold:
        # Above the threshold: do not show an ad.
        ad_predict = 1
    else:
        # Otherwise: show an ad.
        ad_predict = 2

    result = {
        'predict_tag': 'xgboost',
        'ad_0_predict': ad_0_predict,
        'ad_1_predict': ad_1_predict,
        'predict_res': predict_res,
        'threshold': threshold,
        'ad_predict': ad_predict,
    }
    return result