import numpy as np
import xgboost as xgb
from xgboost.sklearn import XGBClassifier
from utils import RedisHelper
from config import set_config
# Shared Redis helper used below for feature and threshold lookups.
redis_helper = RedisHelper()
# Project configuration object; AD_MODEL_CONFIG is read from it at predict time.
config_ = set_config()
# Model loading: runs once at import time. The serialized booster is read from
# a path relative to the process working directory and attached to an
# sklearn-style wrapper so that predict_proba() can be called on it.
model = XGBClassifier()
booster = xgb.Booster()
booster.load_model('./data/ad_xgb.model')
# NOTE(review): this sets the wrapper's private `_Booster` attribute directly
# instead of going through a public loading API; it may be sensitive to the
# installed xgboost version - confirm when upgrading.
model._Booster = booster


def _load_feature(key_prefix, app_type, item_id):
    """Fetch one serialized feature list from Redis, with a default fallback.

    The entry ``{key_prefix}{app_type}:{item_id}`` is tried first; when it is
    missing, the default entry ``{key_prefix}{app_type}:-1`` is used instead.

    Args:
        key_prefix: Redis key prefix for this feature family.
        app_type: Application type, used as part of the key.
        item_id: Identifier of the user or video the feature belongs to.

    Returns:
        The evaluated feature value. The caller concatenates it with other
        lists, so it is expected to be a list.

    Raises:
        TypeError: If neither the specific nor the default entry exists.
    """
    raw = redis_helper.get_data_from_redis(key_name=f"{key_prefix}{app_type}:{item_id}")
    if raw is None:
        default_key = f"{key_prefix}{app_type}:-1"
        raw = redis_helper.get_data_from_redis(key_name=default_key)
        if raw is None:
            # Previously eval(None) raised a bare TypeError here; the type is
            # kept for backward compatibility, but the message now names the
            # key that is missing.
            raise TypeError(f"no feature data found in redis, default key: {default_key!r}")
    # SECURITY NOTE(review): eval() executes whatever expression is stored in
    # Redis. If every stored value is a plain Python literal, prefer
    # ast.literal_eval (or JSON) - confirm the stored format before switching.
    return eval(raw)


def xgboost_predict(app_type, mid, video_id, abtest_id, ab_test_code):
    """Decide whether to show an ad, based on the XGBoost model.

    The model is scored twice on the same user + video features, once with a
    trailing flag of 0 and once with a trailing flag of 1. The difference of
    the two positive-class probabilities is compared with a threshold that is
    configured per AB-test group.

    Args:
        app_type: Application type, part of the feature Redis keys.
        mid: User identifier, part of the user feature key.
        video_id: Video identifier, part of the video feature key.
        abtest_id: AB-test id, part of the threshold key.
        ab_test_code: AB-test group code, part of the threshold key.

    Returns:
        dict with the keys ``predict_tag`` (always ``'xgboost'``),
        ``ad_0_predict``, ``ad_1_predict``, ``predict_res`` (their
        difference), ``threshold`` and ``ad_predict`` (1 when
        ``predict_res > threshold``, otherwise 2).

    Raises:
        TypeError: If a feature is missing from Redis, including its default.
    """
    xgb_config = config_.AD_MODEL_CONFIG['xgb']

    # 1. User features (falls back to the default user entry).
    user_feature = _load_feature(xgb_config['predict_user_feature_key_prefix'], app_type, mid)

    # 2. Video features (falls back to the default video entry).
    video_feature = _load_feature(xgb_config['predict_video_feature_key_prefix'], app_type, video_id)

    # 3./4. Build both feature rows and score them in a single model call.
    # Per the original comments, the trailing 0 is the "ad shown" case and the
    # trailing 1 is the "no ad" case - NOTE(review): confirm against training.
    base_feature = user_feature + video_feature
    probabilities = model.predict_proba(np.array([base_feature + [0], base_feature + [1]]))
    ad_0_predict = probabilities[0][1]
    ad_1_predict = probabilities[1][1]

    # 5. Difference between the two positive-class probabilities.
    predict_res = ad_0_predict - ad_1_predict

    # 6. Threshold for this AB-test group; defaults to 0 when not configured.
    threshold_key_name = f"{xgb_config['threshold_key_prefix']}{abtest_id}:{ab_test_code}"
    threshold = redis_helper.get_data_from_redis(key_name=threshold_key_name)
    threshold = 0 if threshold is None else float(threshold)

    # 7. Above the threshold: do not show the ad (1); otherwise show it (2).
    ad_predict = 1 if predict_res > threshold else 2

    return {
        'predict_tag': 'xgboost',
        'ad_0_predict': ad_0_predict,
        'ad_1_predict': ad_1_predict,
        'predict_res': predict_res,
        'threshold': threshold,
        'ad_predict': ad_predict,
    }