import datetime
import sys
import traceback
import numpy as np
import pandas as pd
from odps import ODPS
from my_utils import data_check, get_feature_data, send_msg_to_feishu_new, RedisHelper
from my_config import set_config
from log import Log
# Module-level objects shared by every function below: the project config (the second
# value returned by set_config() is not used in this script), the logger, and the Redis
# helper used to read predictions and write thresholds.
config_, _ = set_config()
log_ = Log()
redis_helper = RedisHelper()


def predict_user_group_share_rate(dt, app_type):
    """Estimate each user group's share rate when ads are shown.

    Reads the ``users_share_rate`` feature table for partition ``dt``, keeps the rows
    whose ``apptype`` equals ``app_type``, and rescales every group using the
    ``allmids`` aggregate row:

        group_ad_share_rate = sharerate_ad * allmids.sharerate_ad / sharerate_all

    :param dt: partition date forwarded to ``get_feature_data``
    :param app_type: app type to keep (compared with ``apptype`` cast to int)
    :return: DataFrame of the per-group rows (``allmids`` excluded) with an added
        ``group_ad_share_rate`` column
    :raises IndexError: if there is no ``allmids`` row for ``app_type``
    """
    # Fetch the user-group features
    project = config_.ad_model_data['users_share_rate'].get('project')
    table = config_.ad_model_data['users_share_rate'].get('table')
    features = [
        'apptype',
        'group',
        'sharerate_all',
        'sharerate_ad'
    ]
    user_group_df = get_feature_data(project=project, table=table, features=features, dt=dt)
    user_group_df['apptype'] = user_group_df['apptype'].astype(int)
    # Explicit copy: the columns are reassigned below, and assigning into a boolean-filtered
    # slice triggers pandas' SettingWithCopyWarning.
    user_group_df = user_group_df[user_group_df['apptype'] == app_type].copy()
    user_group_df['sharerate_all'] = user_group_df['sharerate_all'].astype(float)
    user_group_df['sharerate_ad'] = user_group_df['sharerate_ad'].astype(float)

    # Share rate with ads over all user groups for the last 30 days (the 'allmids' row)
    all_groups_row = user_group_df[user_group_df['group'] == 'allmids']
    if all_groups_row.empty:
        # Same exception type as the bare `.values[0]` lookup, but with a useful message.
        raise IndexError(f"no 'allmids' row found, app_type = {app_type}, dt = {dt}")
    ad_all_group_share_rate = all_groups_row['sharerate_ad'].values[0]
    user_group_df = user_group_df[user_group_df['group'] != 'allmids'].copy()

    # Per-group share rate with ads.
    # NOTE(review): a group with sharerate_all == 0 yields inf/NaN here — confirm upstream never emits 0.
    user_group_df['group_ad_share_rate'] = \
        user_group_df['sharerate_ad'] * float(ad_all_group_share_rate) / user_group_df['sharerate_all']
    return user_group_df


def predict_video_share_rate(dt, app_type):
    """Estimate each video's share rate when ads are shown.

    Reads the ``videos_share_rate`` feature table for partition ``dt``, keeps the rows
    whose ``apptype`` equals ``app_type``, and rescales every video using the
    ``allvideos`` aggregate row:

        video_ad_share_rate = sharerate_ad * allvideos.sharerate_ad / sharerate_all

    :param dt: partition date forwarded to ``get_feature_data``
    :param app_type: app type to keep (compared with ``apptype`` cast to int)
    :return: DataFrame of the per-video rows (``allvideos`` excluded) with an added
        ``video_ad_share_rate`` column
    :raises IndexError: if there is no ``allvideos`` row for ``app_type``
    """
    # Fetch the video features
    project = config_.ad_model_data['videos_share_rate'].get('project')
    table = config_.ad_model_data['videos_share_rate'].get('table')
    features = [
        'apptype',
        'videoid',
        'sharerate_all',
        'sharerate_ad'
    ]
    video_df = get_feature_data(project=project, table=table, features=features, dt=dt)
    video_df['apptype'] = video_df['apptype'].astype(int)
    # Explicit copy: the columns are reassigned below, and assigning into a boolean-filtered
    # slice triggers pandas' SettingWithCopyWarning.
    video_df = video_df[video_df['apptype'] == app_type].copy()
    video_df['sharerate_all'] = video_df['sharerate_all'].astype(float)
    video_df['sharerate_ad'] = video_df['sharerate_ad'].astype(float)

    # Share rate with ads over all videos for the last 30 days (the 'allvideos' row)
    all_videos_row = video_df[video_df['videoid'] == 'allvideos']
    if all_videos_row.empty:
        # Same exception type as the bare `.values[0]` lookup, but with a useful message.
        raise IndexError(f"no 'allvideos' row found, app_type = {app_type}, dt = {dt}")
    ad_all_videos_share_rate = all_videos_row['sharerate_ad'].values[0]
    video_df = video_df[video_df['videoid'] != 'allvideos'].copy()

    # Per-video share rate with ads.
    # NOTE(review): a video with sharerate_all == 0 yields inf/NaN here — confirm upstream never emits 0.
    video_df['video_ad_share_rate'] = \
        video_df['sharerate_ad'] * float(ad_all_videos_share_rate) / video_df['sharerate_all']
    return video_df


def predict_ad_group_video(dt, config_key, config_param, threshold_record):
    """Compute per-user-group ad thresholds for one A/B-test config and write them to Redis.

    Reads the user-group predictions and the video predictions for ``dt`` from Redis
    sorted sets and scores every (video, group) pair as
    ``video_ad_share_rate * group_ad_share_rate``. For every abtest group in the
    threshold mapping, it stores each user group's mean score, scaled by the configured
    coefficient, as that group's threshold.

    :param dt: date string used as the suffix of the Redis keys that are read
    :param config_key: ``"<abtest_id>-<config_tag>"``; the abtest id selects the
        mapping in ``threshold_record``, and both parts are used in the written key names
    :param config_param: config dict; reads ``user.data``, ``user.rule``,
        ``video.data`` and, optionally, ``more_ad``, ``care_model`` and ``threshold_rate``
    :param threshold_record: ``{abtest_id: {abtest_group: {'group': multiplier,
        'mean_group': multiplier}}}``
    :raises SystemExit: (exit code 1) when the group data or the video data is empty
    """
    log_.info(f"config_key = {config_key} update start ...")
    # User-group predictions, read as (group, group_ad_share_rate) pairs
    user_data_key = config_param['user'].get('data')
    user_rule_key = config_param['user'].get('rule')
    group_key_name = f"{config_.KEY_NAME_PREFIX_AD_GROUP}{user_data_key}:{user_rule_key}:{dt}"
    group_data = redis_helper.get_all_data_from_zset(key_name=group_key_name, with_scores=True)
    if group_data is None:
        log_.info(f"group data is None!")
    group_df = pd.DataFrame(data=group_data, columns=['group', 'group_ad_share_rate'])
    group_df = group_df[group_df['group'] != 'mean_group']
    log_.info(f"group_df count = {len(group_df)}")

    # Video predictions, read as (videoid, video_ad_share_rate) pairs
    video_data_key = config_param['video'].get('data')
    video_key_name = f"{config_.KEY_NAME_PREFIX_AD_VIDEO}{video_data_key}:{dt}"
    video_data = redis_helper.get_all_data_from_zset(key_name=video_key_name, with_scores=True)
    if video_data is None:
        log_.info(f"video data is None!")
    video_df = pd.DataFrame(data=video_data, columns=['videoid', 'video_ad_share_rate'])
    video_df = video_df[video_df['videoid'] != -1]
    log_.info(f"video_df count = {len(video_df)}")

    if len(group_df) == 0 or len(video_df) == 0:
        log_.error(f"config_key = {config_key}: group_df or video_df is empty, exit!")
        sys.exit(1)

    # One score column per user group. Work on an explicit copy, because video_df is a
    # filtered slice and columns are added to it below.
    predict_df = video_df.copy()
    all_group_data = []
    for _, item in group_df.iterrows():
        predict_df[item['group']] = predict_df['video_ad_share_rate'] * item['group_ad_share_rate']
        all_group_data.extend(predict_df[item['group']].tolist())

    # Compute the thresholds for every abtest group configured for this abtest id
    ad_threshold_mappings = threshold_record.get(config_key.split('-')[0])
    # Config options that do not change across abtest groups
    more_ad = config_param.get('more_ad', None)  # {group: extra threshold multiplier}
    care_model = config_param.get('care_model', None)
    threshold_rate = config_param.get('threshold_rate', None)
    for abtest_group, ad_threshold_mapping in ad_threshold_mappings.items():
        # Each group's threshold is the mean of its score column times the 'group'
        # coefficient; 'mean_group' uses the mean over all (video, group) scores.
        threshold_data = {}
        for _, item in group_df.iterrows():
            threshold_data[item['group']] = predict_df[item['group']].mean() * ad_threshold_mapping['group']
        threshold_data['mean_group'] = np.mean(all_group_data) * ad_threshold_mapping['mean_group']
        # User groups that should get more ads: scale their threshold by the configured rate
        if more_ad is not None:
            for group_key, group_threshold_rate in more_ad.items():
                threshold_data[group_key] = threshold_data[group_key] * group_threshold_rate
        log_.info(f"config_key = {config_key}, abtest_group = {abtest_group}, threshold_data = {threshold_data}")

        # Write the thresholds to Redis (2-day TTL)
        abtest_config_list = config_key.split('-')
        abtest_id, abtest_config_tag = abtest_config_list[0], abtest_config_list[1]
        for key, val in threshold_data.items():
            key_name = f"{config_.KEY_NAME_PREFIX_AD_THRESHOLD}{abtest_id}:{abtest_config_tag}:{abtest_group}:{key}"
            redis_helper.set_data_to_redis(key_name=key_name, value=val, expire_time=2 * 24 * 3600)

        # Care-mode experiment: every threshold scaled by threshold_rate, stored under its own prefix
        if care_model is True:
            care_model_threshold_data = {}
            for key, val in threshold_data.items():
                up_val = val * threshold_rate
                care_model_threshold_data[key] = up_val
                up_key_name = \
                    f"{config_.KEY_NAME_PREFIX_AD_THRESHOLD_CARE_MODEL}{abtest_id}:{abtest_config_tag}:{abtest_group}:{key}"
                redis_helper.set_data_to_redis(key_name=up_key_name, value=up_val, expire_time=2 * 24 * 3600)
            log_.info(f"config_key = {config_key}, abtest_group = {abtest_group}, "
                      f"care_model_threshold_data = {care_model_threshold_data}")

    log_.info(f"config_key = {config_key} update end!")


def predict_ad_group_video_mix_with_add(dt, config_key, config_param, threshold_record):
    """Compute ad thresholds from a weighted sum of the share and not-exit predictions.

    This follows the same flow as ``predict_ad_group_video``, but each side (user group
    and video) has two predictions read from Redis:
    - the share rate after an ad is shown (``config_param['share']``);
    - the rate of not leaving right after the ad (``config_param['out']``).

    They are combined as ``share_weight * share + out_weight * not_out``, with the
    weights taken from ``config_param['mix_param']``. Only groups and videos present in
    both predictions are kept (inner merge).

    :param dt: date string used as the suffix of the Redis keys that are read
    :param config_key: ``"<abtest_id>-<config_tag>"``; the abtest id selects the
        mapping in ``threshold_record``, and both parts are used in the written key names
    :param config_param: config dict with ``share``/``out`` (each with ``user.data``,
        ``user.rule`` and ``video.data``), ``mix_param`` and, optionally, ``more_ad``,
        ``care_model`` and ``threshold_rate``
    :param threshold_record: ``{abtest_id: {abtest_group: {'group': multiplier,
        'mean_group': multiplier}}}``
    :raises SystemExit: (exit code 1) when any of the four inputs is empty, or when the
        share and not-exit data have no group or no video in common
    """
    log_.info(f"config_key = {config_key} update start ...")
    # ###### Share-target data
    # User-group predictions (probability of sharing after an ad is shown)
    share_user_data_key = config_param['share']['user'].get('data')
    share_user_rule_key = config_param['share']['user'].get('rule')
    share_group_key_name = f"{config_.KEY_NAME_PREFIX_AD_GROUP}{share_user_data_key}:{share_user_rule_key}:{dt}"
    share_group_data = redis_helper.get_all_data_from_zset(key_name=share_group_key_name, with_scores=True)
    if share_group_data is None:
        log_.info(f"share group data is None!")
    share_group_df = pd.DataFrame(data=share_group_data, columns=['group', 'group_ad_share_rate'])
    share_group_df = share_group_df[share_group_df['group'] != 'mean_group']
    log_.info(f"share_group_df count = {len(share_group_df)}")

    # Video predictions (probability of sharing after an ad is shown)
    share_video_data_key = config_param['share']['video'].get('data')
    share_video_key_name = f"{config_.KEY_NAME_PREFIX_AD_VIDEO}{share_video_data_key}:{dt}"
    share_video_data = redis_helper.get_all_data_from_zset(key_name=share_video_key_name, with_scores=True)
    if share_video_data is None:
        log_.info(f"share video data is None!")
    share_video_df = pd.DataFrame(data=share_video_data, columns=['videoid', 'video_ad_share_rate'])
    share_video_df = share_video_df[share_video_df['videoid'] != -1]
    log_.info(f"share_video_df count = {len(share_video_df)}")

    # Both share inputs must be non-empty (previously only share_video_df was checked, twice)
    if len(share_group_df) == 0 or len(share_video_df) == 0:
        log_.error(f"config_key = {config_key}: share_group_df or share_video_df is empty, exit!")
        sys.exit(1)

    # ###### Not-exit-target data
    # User-group predictions (probability of not leaving right after an ad)
    out_user_data_key = config_param['out']['user'].get('data')
    out_user_rule_key = config_param['out']['user'].get('rule')
    out_group_key_name = f"{config_.KEY_NAME_PREFIX_AD_GROUP}{out_user_data_key}:{out_user_rule_key}:{dt}"
    out_group_data = redis_helper.get_all_data_from_zset(key_name=out_group_key_name, with_scores=True)
    if out_group_data is None:
        log_.info(f"out group data is None!")
    out_group_df = pd.DataFrame(data=out_group_data, columns=['group', 'group_ad_not_out_rate'])
    out_group_df = out_group_df[out_group_df['group'] != 'mean_group']
    log_.info(f"out_group_df count = {len(out_group_df)}")

    # Video predictions (probability of not leaving right after an ad)
    out_video_data_key = config_param['out']['video'].get('data')
    out_video_key_name = f"{config_.KEY_NAME_PREFIX_AD_VIDEO}{out_video_data_key}:{dt}"
    out_video_data = redis_helper.get_all_data_from_zset(key_name=out_video_key_name, with_scores=True)
    if out_video_data is None:
        log_.info(f"out video data is None!")
    out_video_df = pd.DataFrame(data=out_video_data, columns=['videoid', 'video_ad_not_out_rate'])
    out_video_df = out_video_df[out_video_df['videoid'] != -1]
    log_.info(f"out_video_df count = {len(out_video_df)}")

    # Both not-exit inputs must be non-empty (previously this re-checked the share data)
    if len(out_group_df) == 0 or len(out_video_df) == 0:
        log_.error(f"config_key = {config_key}: out_group_df or out_video_df is empty, exit!")
        sys.exit(1)

    # Weighted-sum fusion
    share_weight = config_param['mix_param']['share_weight']
    out_weight = config_param['mix_param']['out_weight']
    # User side
    group_df = pd.merge(share_group_df, out_group_df, on='group')
    group_df['group_rate'] = \
        share_weight * group_df['group_ad_share_rate'] + out_weight * group_df['group_ad_not_out_rate']
    # Video side
    video_df = pd.merge(share_video_df, out_video_df, on='videoid')
    video_df['video_rate'] = \
        share_weight * video_df['video_ad_share_rate'] + out_weight * video_df['video_ad_not_out_rate']
    # An empty inner merge would leave only a NaN 'mean_group' threshold to write
    if len(group_df) == 0 or len(video_df) == 0:
        log_.error(f"config_key = {config_key}: no common group or video between share and out data, exit!")
        sys.exit(1)

    # One score column per user group
    predict_df = video_df.copy()
    all_group_data = []
    for _, item in group_df.iterrows():
        predict_df[item['group']] = predict_df['video_rate'] * item['group_rate']
        all_group_data.extend(predict_df[item['group']].tolist())

    # Compute the thresholds for every abtest group configured for this abtest id
    ad_threshold_mappings = threshold_record.get(config_key.split('-')[0])
    # Config options that do not change across abtest groups
    more_ad = config_param.get('more_ad', None)  # {group: extra threshold multiplier}
    care_model = config_param.get('care_model', None)
    threshold_rate = config_param.get('threshold_rate', None)
    for abtest_group, ad_threshold_mapping in ad_threshold_mappings.items():
        # Each group's threshold is the mean of its score column times the 'group'
        # coefficient; 'mean_group' uses the mean over all (video, group) scores.
        threshold_data = {}
        for _, item in group_df.iterrows():
            threshold_data[item['group']] = predict_df[item['group']].mean() * ad_threshold_mapping['group']
        threshold_data['mean_group'] = np.mean(all_group_data) * ad_threshold_mapping['mean_group']
        # User groups that should get more ads: scale their threshold by the configured rate
        if more_ad is not None:
            for group_key, group_threshold_rate in more_ad.items():
                threshold_data[group_key] = threshold_data[group_key] * group_threshold_rate
        log_.info(f"config_key = {config_key}, abtest_group = {abtest_group}, threshold_data = {threshold_data}")

        # Write the thresholds to Redis (2-day TTL)
        abtest_config_list = config_key.split('-')
        abtest_id, abtest_config_tag = abtest_config_list[0], abtest_config_list[1]
        for key, val in threshold_data.items():
            key_name = f"{config_.KEY_NAME_PREFIX_AD_THRESHOLD}{abtest_id}:{abtest_config_tag}:{abtest_group}:{key}"
            redis_helper.set_data_to_redis(key_name=key_name, value=val, expire_time=2 * 24 * 3600)

        # Care-mode experiment: every threshold scaled by threshold_rate, stored under its own prefix
        if care_model is True:
            care_model_threshold_data = {}
            for key, val in threshold_data.items():
                up_val = val * threshold_rate
                care_model_threshold_data[key] = up_val
                up_key_name = \
                    f"{config_.KEY_NAME_PREFIX_AD_THRESHOLD_CARE_MODEL}{abtest_id}:{abtest_config_tag}:{abtest_group}:{key}"
                redis_helper.set_data_to_redis(key_name=up_key_name, value=up_val, expire_time=2 * 24 * 3600)
            log_.info(f"config_key = {config_key}, abtest_group = {abtest_group}, "
                      f"care_model_threshold_data = {care_model_threshold_data}")

    log_.info(f"config_key = {config_key} update end!")


def predict_ad_group_video_mix_with_multiply(dt, config_key, config_param, threshold_record):
    """Compute ad thresholds from the product of the share and not-exit predictions.

    This follows the same flow as ``predict_ad_group_video``, but each side (user group
    and video) has two predictions read from Redis:
    - the share rate after an ad is shown (``config_param['share']``);
    - the rate of not leaving right after the ad (``config_param['out']``).

    They are combined as ``share * not_out``. Only groups and videos present in both
    predictions are kept (inner merge).

    :param dt: date string used as the suffix of the Redis keys that are read
    :param config_key: ``"<abtest_id>-<config_tag>"``; the abtest id selects the
        mapping in ``threshold_record``, and both parts are used in the written key names
    :param config_param: config dict with ``share``/``out`` (each with ``user.data``,
        ``user.rule`` and ``video.data``) and, optionally, ``more_ad``, ``care_model``
        and ``threshold_rate``
    :param threshold_record: ``{abtest_id: {abtest_group: {'group': multiplier,
        'mean_group': multiplier}}}``
    :raises SystemExit: (exit code 1) when any of the four inputs is empty, or when the
        share and not-exit data have no group or no video in common
    """
    log_.info(f"config_key = {config_key} update start ...")
    # ###### Share-target data
    # User-group predictions (probability of sharing after an ad is shown)
    share_user_data_key = config_param['share']['user'].get('data')
    share_user_rule_key = config_param['share']['user'].get('rule')
    share_group_key_name = f"{config_.KEY_NAME_PREFIX_AD_GROUP}{share_user_data_key}:{share_user_rule_key}:{dt}"
    share_group_data = redis_helper.get_all_data_from_zset(key_name=share_group_key_name, with_scores=True)
    if share_group_data is None:
        log_.info(f"share group data is None!")
    share_group_df = pd.DataFrame(data=share_group_data, columns=['group', 'group_ad_share_rate'])
    share_group_df = share_group_df[share_group_df['group'] != 'mean_group']
    log_.info(f"share_group_df count = {len(share_group_df)}")

    # Video predictions (probability of sharing after an ad is shown)
    share_video_data_key = config_param['share']['video'].get('data')
    share_video_key_name = f"{config_.KEY_NAME_PREFIX_AD_VIDEO}{share_video_data_key}:{dt}"
    share_video_data = redis_helper.get_all_data_from_zset(key_name=share_video_key_name, with_scores=True)
    if share_video_data is None:
        log_.info(f"share video data is None!")
    share_video_df = pd.DataFrame(data=share_video_data, columns=['videoid', 'video_ad_share_rate'])
    share_video_df = share_video_df[share_video_df['videoid'] != -1]
    log_.info(f"share_video_df count = {len(share_video_df)}")

    # Both share inputs must be non-empty (previously only share_video_df was checked, twice)
    if len(share_group_df) == 0 or len(share_video_df) == 0:
        log_.error(f"config_key = {config_key}: share_group_df or share_video_df is empty, exit!")
        sys.exit(1)

    # ###### Not-exit-target data
    # User-group predictions (probability of not leaving right after an ad)
    out_user_data_key = config_param['out']['user'].get('data')
    out_user_rule_key = config_param['out']['user'].get('rule')
    out_group_key_name = f"{config_.KEY_NAME_PREFIX_AD_GROUP}{out_user_data_key}:{out_user_rule_key}:{dt}"
    out_group_data = redis_helper.get_all_data_from_zset(key_name=out_group_key_name, with_scores=True)
    if out_group_data is None:
        log_.info(f"out group data is None!")
    out_group_df = pd.DataFrame(data=out_group_data, columns=['group', 'group_ad_not_out_rate'])
    out_group_df = out_group_df[out_group_df['group'] != 'mean_group']
    log_.info(f"out_group_df count = {len(out_group_df)}")

    # Video predictions (probability of not leaving right after an ad)
    out_video_data_key = config_param['out']['video'].get('data')
    out_video_key_name = f"{config_.KEY_NAME_PREFIX_AD_VIDEO}{out_video_data_key}:{dt}"
    out_video_data = redis_helper.get_all_data_from_zset(key_name=out_video_key_name, with_scores=True)
    if out_video_data is None:
        log_.info(f"out video data is None!")
    out_video_df = pd.DataFrame(data=out_video_data, columns=['videoid', 'video_ad_not_out_rate'])
    out_video_df = out_video_df[out_video_df['videoid'] != -1]
    log_.info(f"out_video_df count = {len(out_video_df)}")

    # Both not-exit inputs must be non-empty (previously this re-checked the share data)
    if len(out_group_df) == 0 or len(out_video_df) == 0:
        log_.error(f"config_key = {config_key}: out_group_df or out_video_df is empty, exit!")
        sys.exit(1)

    # Product fusion
    # User side
    group_df = pd.merge(share_group_df, out_group_df, on='group')
    group_df['group_rate'] = group_df['group_ad_share_rate'] * group_df['group_ad_not_out_rate']
    # Video side
    video_df = pd.merge(share_video_df, out_video_df, on='videoid')
    video_df['video_rate'] = video_df['video_ad_share_rate'] * video_df['video_ad_not_out_rate']
    # An empty inner merge would leave only a NaN 'mean_group' threshold to write
    if len(group_df) == 0 or len(video_df) == 0:
        log_.error(f"config_key = {config_key}: no common group or video between share and out data, exit!")
        sys.exit(1)

    # One score column per user group
    predict_df = video_df.copy()
    all_group_data = []
    for _, item in group_df.iterrows():
        predict_df[item['group']] = predict_df['video_rate'] * item['group_rate']
        all_group_data.extend(predict_df[item['group']].tolist())

    # Compute the thresholds for every abtest group configured for this abtest id
    ad_threshold_mappings = threshold_record.get(config_key.split('-')[0])
    # Config options that do not change across abtest groups
    more_ad = config_param.get('more_ad', None)  # {group: extra threshold multiplier}
    care_model = config_param.get('care_model', None)
    threshold_rate = config_param.get('threshold_rate', None)
    for abtest_group, ad_threshold_mapping in ad_threshold_mappings.items():
        # Each group's threshold is the mean of its score column times the 'group'
        # coefficient; 'mean_group' uses the mean over all (video, group) scores.
        threshold_data = {}
        for _, item in group_df.iterrows():
            threshold_data[item['group']] = predict_df[item['group']].mean() * ad_threshold_mapping['group']
        threshold_data['mean_group'] = np.mean(all_group_data) * ad_threshold_mapping['mean_group']
        # User groups that should get more ads: scale their threshold by the configured rate
        if more_ad is not None:
            for group_key, group_threshold_rate in more_ad.items():
                threshold_data[group_key] = threshold_data[group_key] * group_threshold_rate
        log_.info(f"config_key = {config_key}, abtest_group = {abtest_group}, threshold_data = {threshold_data}")

        # Write the thresholds to Redis (2-day TTL)
        abtest_config_list = config_key.split('-')
        abtest_id, abtest_config_tag = abtest_config_list[0], abtest_config_list[1]
        for key, val in threshold_data.items():
            key_name = f"{config_.KEY_NAME_PREFIX_AD_THRESHOLD}{abtest_id}:{abtest_config_tag}:{abtest_group}:{key}"
            redis_helper.set_data_to_redis(key_name=key_name, value=val, expire_time=2 * 24 * 3600)

        # Care-mode experiment: every threshold scaled by threshold_rate, stored under its own prefix
        if care_model is True:
            care_model_threshold_data = {}
            for key, val in threshold_data.items():
                up_val = val * threshold_rate
                care_model_threshold_data[key] = up_val
                up_key_name = \
                    f"{config_.KEY_NAME_PREFIX_AD_THRESHOLD_CARE_MODEL}{abtest_id}:{abtest_config_tag}:{abtest_group}:{key}"
                redis_helper.set_data_to_redis(key_name=up_key_name, value=up_val, expire_time=2 * 24 * 3600)
            log_.info(f"config_key = {config_key}, abtest_group = {abtest_group}, "
                      f"care_model_threshold_data = {care_model_threshold_data}")

    log_.info(f"config_key = {config_key} update end!")


def _send_threshold_update_msg(title, msg_list):
    """Send ``msg_list`` under ``title`` through the 'ad_threshold_update_robot' Feishu robot."""
    robot = config_.FEISHU_ROBOT['ad_threshold_update_robot']
    send_msg_to_feishu_new(
        webhook=robot.get('webhook'),
        key_word=robot.get('key_word'),
        title=title,
        msg_list=msg_list
    )


def predict():
    """Refresh the ad thresholds for every A/B-test config and report the outcome to Feishu.

    The steps are:
    1. Load the threshold-parameter record from Redis.
    2. For each entry of ``config_.AD_ABTEST_CONFIG``, run the updater selected by its
       ``threshold_mix_func``: ``'add'``, ``'multiply'``, or the single-model updater
       for any other value.
    3. Write the record back with a 2-day TTL.
    4. Send a success message to Feishu.

    Exceptions are logged and reported to Feishu but not re-raised. The updaters call
    ``sys.exit(1)`` on empty input. ``SystemExit`` is not an ``Exception`` subclass, so
    it is handled separately: reported, then re-raised so the exit code is preserved.
    """
    try:
        now_date = datetime.datetime.today()
        dt = datetime.datetime.strftime(now_date, '%Y%m%d')
        log_.info(f"dt = {dt}")
        # Threshold-parameter record saved by the previous run
        threshold_record = redis_helper.get_data_from_redis(key_name=config_.KEY_NAME_PREFIX_AD_THRESHOLD_RECORD)
        if not threshold_record:
            # The record expires after 2 days without a successful run. Fail with a clear
            # message instead of the opaque error eval() raises on a missing value.
            raise ValueError(
                f"threshold record '{config_.KEY_NAME_PREFIX_AD_THRESHOLD_RECORD}' is empty or missing")
        # NOTE(review): eval() executes whatever is stored under this Redis key. The value is
        # written below as str(dict); if it only ever holds plain literals, ast.literal_eval
        # would be a safer drop-in. Confirm before switching.
        threshold_record = eval(threshold_record)
        log_.info(f"threshold_record = {threshold_record}")

        # threshold_mix_func -> updater; any other value uses the single-model updater
        update_funcs = {
            'add': predict_ad_group_video_mix_with_add,
            'multiply': predict_ad_group_video_mix_with_multiply,
        }
        for config_key, config_param in config_.AD_ABTEST_CONFIG.items():
            update_func = update_funcs.get(config_param.get('threshold_mix_func'), predict_ad_group_video)
            update_func(dt=dt,
                        config_key=config_key,
                        config_param=config_param,
                        threshold_record=threshold_record)

        # Write the record back (with a fresh 2-day TTL) for the next run
        redis_helper.set_data_to_redis(key_name=config_.KEY_NAME_PREFIX_AD_THRESHOLD_RECORD,
                                       value=str(threshold_record),
                                       expire_time=2 * 24 * 3600)
        _send_threshold_update_msg(
            title='广告模型阈值更新完成',
            msg_list=[
                f"env: rov-offline {config_.ENV_TEXT}",
                f"finished time: {datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d %H:%M:%S')}",
            ]
        )
    except SystemExit as e:
        # Raised by the updaters' sys.exit(1) on empty input; report it, then keep the exit code.
        log_.error(f"广告模型阈值更新失败, exit code: {e.code}")
        _send_threshold_update_msg(
            title='广告模型阈值更新失败',
            msg_list=[
                f"env: rov-offline {config_.ENV_TEXT}",
                f"now time: {datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d %H:%M:%S')}",
                f"exception: SystemExit({e.code})",
            ]
        )
        raise
    except Exception as e:
        tb = traceback.format_exc()
        log_.error(f"广告模型阈值更新失败, exception: {e}, traceback: {tb}")
        _send_threshold_update_msg(
            title='广告模型阈值更新失败',
            msg_list=[
                f"env: rov-offline {config_.ENV_TEXT}",
                f"now time: {datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d %H:%M:%S')}",
                f"exception: {e}",
                f"traceback: {tb}",
            ]
        )


if __name__ == '__main__':
    # Script entry point: recompute and publish all ad thresholds for today.
    predict()