Explorar o código

add ad abtest: xgboost model

liqian hai 1 ano
pai
achega
c2c76de012
Modificáronse 4 ficheiros con 90 adicións e 2 borrados
  1. 7 2
      ad_recommend.py
  2. 68 0
      ad_xgboost_predict.py
  3. 15 0
      config.py
  4. BIN=BIN
      data/ad_xgb.model

+ 7 - 2
ad_recommend.py

@@ -1,6 +1,7 @@
 import traceback
 import datetime
 from utils import RedisHelper
+from ad_xgboost_predict import xgboost_predict
 from config import set_config
 from log import Log
 log_ = Log()
@@ -362,9 +363,13 @@ def ad_recommend_predict(app_type, mid, video_id, ab_exp_info, ab_test_code, car
         abtest_param = config_.AD_ABTEST_CONFIG.get(f'{abtest_id}-{abtest_config_tag}')
         if abtest_param is None:
             return None
-
+        predict_model = abtest_param.get('predict_model', None)
         threshold_mix_func = abtest_param.get('threshold_mix_func', None)
-        if threshold_mix_func == 'add':
+        if predict_model == 'xgb':
+            result = xgboost_predict(
+                app_type=app_type, mid=mid, video_id=video_id, abtest_id=abtest_id, ab_test_code=ab_test_code
+            )
+        elif threshold_mix_func == 'add':
             result = predict_mid_video_res_with_add(
                 now_date=now_date,
                 mid=mid,

+ 68 - 0
ad_xgboost_predict.py

@@ -0,0 +1,68 @@
+import numpy as np
+import xgboost as xgb
+from xgboost.sklearn import XGBClassifier
+from utils import RedisHelper
+from config import set_config
+redis_helper = RedisHelper()  # shared redis accessor used by xgboost_predict below
+config_ = set_config()  # configuration object; AD_MODEL_CONFIG is read from it
+# Load the model once, at import time, so each prediction call reuses it
+model = XGBClassifier()  # empty sklearn-style wrapper, filled in below
+booster = xgb.Booster()
+booster.load_model('./data/ad_xgb.model')  # relative path: resolved against the process working directory
+model._Booster = booster  # attach the loaded booster to the wrapper that predict_proba is called on
+
+
+def xgboost_predict(app_type, mid, video_id, abtest_id, ab_test_code):
+    xgb_config = config_.AD_MODEL_CONFIG['xgb']
+    # 1. 获取user特征
+    user_feature_key = f"{xgb_config['predict_user_feature_key_prefix']}{app_type}:{mid}"
+    user_feature = redis_helper.get_data_from_redis(key_name=user_feature_key)
+    if user_feature is None:
+        user_feature_key = f"{xgb_config['predict_user_feature_key_prefix']}{app_type}:-1"
+        user_feature = redis_helper.get_data_from_redis(key_name=user_feature_key)
+    user_feature = eval(user_feature)
+
+    # 2. 获取video特征
+    video_feature_key = f"{xgb_config['predict_video_feature_key_prefix']}{app_type}:{video_id}"
+    video_feature = redis_helper.get_data_from_redis(key_name=video_feature_key)
+    if video_feature is None:
+        video_feature_key = f"{xgb_config['predict_video_feature_key_prefix']}{app_type}:-1"
+        video_feature = redis_helper.get_data_from_redis(key_name=video_feature_key)
+    video_feature = eval(video_feature)
+
+    # 3. 拼接出广告时的特征 & 预测
+    ad_feature_0 = user_feature + video_feature + [0]
+    ad_0_predict = model.predict_proba(np.array([ad_feature_0]))
+    ad_0_predict = ad_0_predict[0][1]
+
+    # 4. 拼接不出广告时的特征 & 预测
+    ad_feature_1 = user_feature + video_feature + [1]
+    ad_1_predict = model.predict_proba(np.array([ad_feature_1]))
+    ad_1_predict = ad_1_predict[0][1]
+
+    # 5. 作差
+    predict_res = ad_0_predict - ad_1_predict
+
+    # 6. 获取阈值
+    threshold_key_name = f"{xgb_config['threshold_key_prefix']}{abtest_id}:{ab_test_code}"
+    threshold = redis_helper.get_data_from_redis(key_name=threshold_key_name)
+    if threshold is None:
+        threshold = 0
+    else:
+        threshold = float(threshold)
+
+    # 7. 阈值判断
+    if predict_res > threshold:
+        # 大于阈值,不出广告
+        ad_predict = 1
+    else:
+        # 否则,出广告
+        ad_predict = 2
+    result = {
+        'predict_tag': 'xgboost',
+        'ad_0_predict': ad_0_predict,
+        'ad_1_predict': ad_1_predict,
+        'predict_res': predict_res,
+        'threshold': threshold,
+        'ad_predict': ad_predict}
+    return result

+ 15 - 0
config.py

@@ -1221,6 +1221,9 @@ class BaseConfig(object):
             'care_model_ab_mid_group': ['mean_group'],
             'threshold_mix_func': 'multiply',
         },  # 所有广告类型本端视频数据 + 优化阈值计算方式 + else非关怀模式人群多出广告 + 分享与不直接跳出融合方案二(乘积融合: p(不直接跳出|出广告) * p(分享|出广告))
+        '195-xgb': {
+            'predict_model': 'xgb'
+        },  # xgboost在线预测
 
         # 票圈短视频
         # '196-a': {
@@ -1631,6 +1634,18 @@ class BaseConfig(object):
     # 广告推荐关怀模式实验阈值结果存放 redis key 前缀,完整格式:ad:threshold:care:{abtestId}:{abtestConfigTag}:{abtestGroup}:{group}
     KEY_NAME_PREFIX_AD_THRESHOLD_CARE_MODEL = 'ad:threshold:care:'
 
+    # Online-prediction configuration for ad models, keyed by model name
+    AD_MODEL_CONFIG = {
+        'xgb': {
+            # Redis key prefix for video features; full key format: ad:xgb:predict:video:{app_type}:{video_id}
+            'predict_video_feature_key_prefix': 'ad:xgb:predict:video:',
+            # Redis key prefix for user features; full key format: ad:xgb:predict:user:{app_type}:{mid}
+            'predict_user_feature_key_prefix': 'ad:xgb:predict:user:',
+            # Redis key prefix for threshold values; full key format: ad:xgb:threshold:{abtestId}:{abtestGroup}
+            'threshold_key_prefix': 'ad:xgb:threshold:'
+        }
+    }
+
 
 class DevelopmentConfig(BaseConfig):
     """开发环境配置"""

BIN=BIN
data/ad_xgb.model