|
@@ -0,0 +1,68 @@
|
|
|
+import numpy as np
|
|
|
+import xgboost as xgb
|
|
|
+from xgboost.sklearn import XGBClassifier
|
|
|
+from utils import RedisHelper
|
|
|
+from config import set_config
|
|
|
+redis_helper = RedisHelper()
|
|
|
+config_ = set_config()
|
|
|
+# 模型加载
|
|
|
+model = XGBClassifier()
|
|
|
+booster = xgb.Booster()
|
|
|
+booster.load_model('./data/ad_xgb.model')
|
|
|
+model._Booster = booster
|
|
|
+
|
|
|
+
|
|
|
+def xgboost_predict(app_type, mid, video_id, abtest_id, ab_test_code):
|
|
|
+ xgb_config = config_.AD_MODEL_CONFIG['xgb']
|
|
|
+ # 1. 获取user特征
|
|
|
+ user_feature_key = f"{xgb_config['predict_user_feature_key_prefix']}{app_type}:{mid}"
|
|
|
+ user_feature = redis_helper.get_data_from_redis(key_name=user_feature_key)
|
|
|
+ if user_feature is None:
|
|
|
+ user_feature_key = f"{xgb_config['predict_user_feature_key_prefix']}{app_type}:-1"
|
|
|
+ user_feature = redis_helper.get_data_from_redis(key_name=user_feature_key)
|
|
|
+ user_feature = eval(user_feature)
|
|
|
+
|
|
|
+ # 2. 获取video特征
|
|
|
+ video_feature_key = f"{xgb_config['predict_video_feature_key_prefix']}{app_type}:{video_id}"
|
|
|
+ video_feature = redis_helper.get_data_from_redis(key_name=video_feature_key)
|
|
|
+ if video_feature is None:
|
|
|
+ video_feature_key = f"{xgb_config['predict_video_feature_key_prefix']}{app_type}:-1"
|
|
|
+ video_feature = redis_helper.get_data_from_redis(key_name=video_feature_key)
|
|
|
+ video_feature = eval(video_feature)
|
|
|
+
|
|
|
+ # 3. 拼接出广告时的特征 & 预测
|
|
|
+ ad_feature_0 = user_feature + video_feature + [0]
|
|
|
+ ad_0_predict = model.predict_proba(np.array([ad_feature_0]))
|
|
|
+ ad_0_predict = ad_0_predict[0][1]
|
|
|
+
|
|
|
+ # 4. 拼接不出广告时的特征 & 预测
|
|
|
+ ad_feature_1 = user_feature + video_feature + [1]
|
|
|
+ ad_1_predict = model.predict_proba(np.array([ad_feature_1]))
|
|
|
+ ad_1_predict = ad_1_predict[0][1]
|
|
|
+
|
|
|
+ # 5. 作差
|
|
|
+ predict_res = ad_0_predict - ad_1_predict
|
|
|
+
|
|
|
+ # 6. 获取阈值
|
|
|
+ threshold_key_name = f"{xgb_config['threshold_key_prefix']}{abtest_id}:{ab_test_code}"
|
|
|
+ threshold = redis_helper.get_data_from_redis(key_name=threshold_key_name)
|
|
|
+ if threshold is None:
|
|
|
+ threshold = 0
|
|
|
+ else:
|
|
|
+ threshold = float(threshold)
|
|
|
+
|
|
|
+ # 7. 阈值判断
|
|
|
+ if predict_res > threshold:
|
|
|
+ # 大于阈值,不出广告
|
|
|
+ ad_predict = 1
|
|
|
+ else:
|
|
|
+ # 否则,出广告
|
|
|
+ ad_predict = 2
|
|
|
+ result = {
|
|
|
+ 'predict_tag': 'xgboost',
|
|
|
+ 'ad_0_predict': ad_0_predict,
|
|
|
+ 'ad_1_predict': ad_1_predict,
|
|
|
+ 'predict_res': predict_res,
|
|
|
+ 'threshold': threshold,
|
|
|
+ 'ad_predict': ad_predict}
|
|
|
+ return result
|