# -*- coding: utf-8 -*-

import pandas as pd
import traceback
import odps
from odps import ODPS
import json
from threading import Timer
from datetime import datetime, timedelta
from db_helper import MysqlHelper
from my_utils import check_table_partition_exits_v2, get_dataframe_from_odps, \
    get_odps_df_of_max_partition, get_odps_instance, get_odps_df_of_recent_partitions
from my_utils import request_post, send_msg_to_feishu
from my_config import set_config
import numpy as np
from log import Log
import os
from argparse import ArgumentParser
from constants import AutoReplyAccountType

CONFIG, _ = set_config()
LOGGER = Log()

BASE_GROUP_NAME = '3rd-party-base'
EXPLORE1_GROUP_NAME = '3rd-party-explore1'
EXPLORE2_GROUP_NAME = '3rd-party-explore2'
# GH_IDS will be updated by get_and_update_gh_ids
GH_IDS = ('default',)

pd.set_option('display.max_rows', None)

TARGET_GH_IDS = (
    'gh_250c51d5ce69',
    'gh_8a29eebc2012',
    'gh_ff16c412ab97',
    'gh_1014734791e0',
    'gh_570967881eae',
    'gh_a7c21403c493',
    'gh_7f062810b4e7',
    'gh_c8060587e6d1',
    'gh_1da8f62f4a0d',
    'gh_56b65b7d4520',
    'gh_eeec7c2e28a5',
    'gh_7c89d5a3e745',
    'gh_ee5b4b07ed8b',
    'gh_0d3c97cc30cc',
    'gh_c783350a9660',
)

CDN_IMG_OPERATOR = "?x-oss-process=image/resize,m_fill,w_600,h_480,limit_0/format,jpg/watermark,image_eXNoL3BpYy93YXRlcm1hcmtlci9pY29uX3BsYXlfd2hpdGUucG5nP3gtb3NzLXByb2Nlc3M9aW1hZ2UvcmVzaXplLHdfMTQ0,g_center"

ODS_PROJECT = "loghubods"
EXPLORE_POOL_TABLE = 'alg_growth_video_return_stats_history'
GH_REPLY_STATS_TABLE = 'alg_growth_3rd_gh_reply_video_stats'
ODPS_RANK_RESULT_TABLE = 'alg_3rd_gh_autoreply_video_rank_data'
GH_DETAIL = 'gh_detail'
RDS_RANK_RESULT_TABLE = 'alg_gh_autoreply_video_rank_data'
STATS_PERIOD_DAYS = 5
SEND_N = 1

def get_and_update_gh_ids(run_dt):
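    """Load the account list for run_dt from the gh_detail table, keep only
    external GZH accounts, append a 'default' account if missing, de-duplicate
    by gh_id, refresh the global GH_IDS tuple and return the account DataFrame."""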
    gh = get_odps_df_of_max_partition(ODS_PROJECT, GH_DETAIL, {'dt': run_dt})
    gh = gh.to_pandas()
    gh = gh[gh['type'] == AutoReplyAccountType.EXTERNAL_GZH.value]
    # handle the 'default' account separately
    if 'default' not in gh['gh_id'].values:
        new_row = pd.DataFrame({'gh_id': ['default'], 'gh_name': ['默认'], 'type': [2], 'category1': ['泛生活']},
                               index=[0])
        gh = pd.concat([gh, new_row], ignore_index=True)

    gh = gh.drop_duplicates(subset=['gh_id'])
    global GH_IDS
    GH_IDS = tuple(gh['gh_id'])
    return gh


def check_data_partition(project, table, data_dt, data_hr=None):
    """检查数据是否准备好"""
    try:
        partition_spec = {'dt': data_dt}
        if data_hr:
            partition_spec['hour'] = data_hr
        part_exist, data_count = check_table_partition_exits_v2(
            project, table, partition_spec)
    except Exception as e:
        LOGGER.error(f"check partition of {project}.{table} failed: {e}")
        data_count = 0
    return data_count


def get_last_strategy_result(project, rank_table, dt_version, key):
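    """Read the latest partition (keyed by ctime) of the rank result table and
    return de-duplicated (gh_id, video_id, strategy_key, sort) rows for the
    given strategy key, used as the current base / fallback strategy."""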
    strategy_df = get_odps_df_of_max_partition(
        project, rank_table, {'ctime': dt_version}
    ).to_pandas()
    sub_df = strategy_df.query(f'strategy_key == "{key}"')
    sub_df = sub_df[['gh_id', 'video_id', 'strategy_key', 'sort']].drop_duplicates()
    return sub_df


def process_reply_stats(project, table, period, run_dt):
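    """Aggregate the recent `period` days of reply stats per (gh_id, video_id),
    add an all-account aggregate under gh_id='default', and compute a smoothed
    score: day0_return / (send_count + 500)."""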
    # fetch multi-day reply/return stats for aggregation
    df = get_odps_df_of_recent_partitions(project, table, period, {'dt': run_dt})
    df = df.to_pandas()

    df['video_id'] = df['video_id'].astype('int64')
    df = df[['gh_id', 'video_id', 'send_count', 'first_visit_uv', 'day0_return']]

    # aggregate within each account
    df = df.groupby(['video_id', 'gh_id']).agg({
        'send_count': 'sum',
        'first_visit_uv': 'sum',
        'day0_return': 'sum'
    }).reset_index()

    # aggregate across all accounts as the 'default' entry
    default_stats_df = df.groupby('video_id').agg({
        'send_count': 'sum',
        'first_visit_uv': 'sum',
        'day0_return': 'sum'
    }).reset_index()
    default_stats_df['gh_id'] = 'default'

    merged_df = pd.concat([df, default_stats_df]).reset_index(drop=True)

    merged_df['score'] = merged_df['day0_return'] / (merged_df['send_count'] + 500)
    return merged_df


def rank_for_layer1(run_dt, run_hour, project, table, gh):
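    """Explore-1 ranking: within each category1, draw SEND_N videos from the
    history pool by weighted random sampling (weight = score, currently the raw
    'ros' value), then assign the sampled videos to every account of that
    category via a merge on category1."""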
    # TODO: add review filtering & retirement logic
    df = get_odps_df_of_max_partition(project, table, {'dt': run_dt})
    df = df.to_pandas()
    # fixed seed per dt_version so reruns produce identical results
    dt_version = f'{run_dt}{run_hour}'
    np.random.seed(int(dt_version) + 1)

    # TODO: revise the weighting strategy
    df['score'] = df['ros']
    # weighted random sampling within each category1 group
    sampled_df = df.groupby('category1').apply(
        lambda x: x.sample(n=SEND_N, weights=x['score'], replace=False)).reset_index(drop=True)
    sampled_df['sort'] = sampled_df.groupby('category1')['score'].rank(method='first', ascending=False).astype(int)
    # sort by score within each category
    sampled_df = sampled_df.sort_values(by=['category1', 'score'], ascending=[True, False]).reset_index(drop=True)
    sampled_df['strategy_key'] = EXPLORE1_GROUP_NAME
    sampled_df['dt_version'] = dt_version
    extend_df = sampled_df.merge(gh, on='category1')
    result_df = extend_df[['strategy_key', 'dt_version', 'gh_id', 'sort', 'video_id', 'score']]
    return result_df


def rank_for_layer2(run_dt, run_hour, project, stats_table, rank_table):
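    """Explore-2 ranking: per-account weighted random sampling over recent reply
    stats. The 'default' account samples from the all-account aggregate; other
    accounts are filtered to day0_return > 100 and fall back to the current base
    strategy when fewer than SEND_N candidates remain."""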
    stats_df = process_reply_stats(project, stats_table, STATS_PERIOD_DAYS, run_dt)

    # fixed seed per dt_version so reruns produce identical results
    dt_version = f'{run_dt}{run_hour}'
    np.random.seed(int(dt_version) + 1)
    # TODO: compute inter-account correlation
    ## for each account pair, take the intersection of videos that have RoVn values,
    ## build a vector from each account's (smoothed) RoVn,
    ## compute the correlation coefficient or cosine similarity of the vectors,
    ## then take a weighted sum of each video's RoVn
    # current basic implementation: compute the explore-2 rank score within each account only

    sampled_dfs = []
    # handle the 'default' account (default-explore2)
    default_stats_df = stats_df.query('gh_id == "default"')
    sampled_df = default_stats_df.sample(n=SEND_N, weights=default_stats_df['score'])
    sampled_df['sort'] = range(1, len(sampled_df) + 1)
    sampled_dfs.append(sampled_df)

    # basic per-account filtering
    df = stats_df.query('day0_return > 100')

    # fallback to base if necessary
    base_strategy_df = get_last_strategy_result(
        project, rank_table, dt_version, BASE_GROUP_NAME)

    for gh_id in GH_IDS:
        if gh_id == 'default':
            continue
        sub_df = df.query(f'gh_id == "{gh_id}"')
        if len(sub_df) < SEND_N:
            LOGGER.warning(
                "gh_id[{}] rows[{}] not enough for layer2, fallback to base"
                .format(gh_id, len(sub_df)))
            sub_df = base_strategy_df.query(f'gh_id == "{gh_id}"').copy()
            sub_df['score'] = sub_df['sort']
        sampled_df = sub_df.sample(n=SEND_N, weights=sub_df['score'])
        sampled_df['sort'] = range(1, len(sampled_df) + 1)
        sampled_dfs.append(sampled_df)

    extend_df = pd.concat(sampled_dfs)
    extend_df['strategy_key'] = EXPLORE2_GROUP_NAME
    extend_df['dt_version'] = dt_version
    result_df = extend_df[['strategy_key', 'dt_version', 'gh_id', 'sort', 'video_id', 'score']]
    return result_df


def rank_for_base(run_dt, run_hour, project, stats_table, rank_table, stg_key):
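    """Base ranking: for non-default accounts, keep videos that are either in
    the current base strategy or have score > 0.1; then, per account (including
    'default'), take the top SEND_N videos by score."""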
    stats_df = process_reply_stats(project, stats_table, STATS_PERIOD_DAYS, run_dt)

    # TODO: support setting the base manually
    dt_version = f'{run_dt}{run_hour}'

    # fetch the current base strategy; the strategy table dt_version (ctime partition) uses the current time
    base_strategy_df = get_last_strategy_result(
        project, rank_table, dt_version, stg_key)

    default_stats_df = stats_df.query('gh_id == "default"')

    # rank within each account to decide the base (exploitation) content for that account (including default)
    # make sure the current base strategy takes part in the ranking, so join first and filter afterwards
    non_default_ids = list(filter(lambda x: x != 'default', GH_IDS))
    gh_ids_str = ','.join(f'"{x}"' for x in non_default_ids)
    stats_df = stats_df.query(f'gh_id in ({gh_ids_str})')

    stats_with_strategy_df = stats_df \
        .merge(
        base_strategy_df,
        on=['gh_id', 'video_id'],
        how='left') \
        .query('strategy_key.notna() or score > 0.1')

    # merge the default and per-account data
    grouped_stats_df = pd.concat([default_stats_df, stats_with_strategy_df]).reset_index()

    def set_top_n(group, n=2):
        group_sorted = group.sort_values(by='score', ascending=False)
        top_n = group_sorted.head(n).copy()
        top_n['sort'] = range(1, len(top_n) + 1)
        return top_n

    ranked_df = grouped_stats_df.groupby('gh_id').apply(set_top_n, SEND_N)
    ranked_df = ranked_df.reset_index(drop=True)
    ranked_df['strategy_key'] = stg_key
    ranked_df['dt_version'] = dt_version
    ranked_df = ranked_df[['strategy_key', 'dt_version', 'gh_id', 'sort', 'video_id', 'score']]
    return ranked_df


def check_result_data(df):
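    """Sanity check: every (gh_id, strategy_key) combination must contain
    exactly SEND_N records, otherwise raise."""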
    for gh_id in GH_IDS:
        for key in (EXPLORE1_GROUP_NAME, EXPLORE2_GROUP_NAME, BASE_GROUP_NAME):
            sub_df = df.query(f'gh_id == "{gh_id}" and strategy_key == "{key}"')
            n_records = len(sub_df)
            if n_records != SEND_N:
                raise Exception(f"Unexpected record count: {gh_id},{key},{n_records}")


def postprocess_override_by_config(df, dt_version):
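    """Apply manual overrides from configs/3rd_gh_reply_video.json: for each
    configured (gh_id, strategy_key, position), drop the computed record at that
    position and append the configured video with score 0.0."""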
    with open("configs/3rd_gh_reply_video.json") as f:
        config = json.load(f)
    override_data = {
        'strategy_key': [],
        'gh_id': [],
        'sort': [],
        'video_id': []
    }

    for gh_id in config:
        gh_config = config[gh_id]
        for key in gh_config:
            for video_config in gh_config[key]:
                # drop the computed record currently at this (gh_id, strategy, position)
                position = video_config['position']
                video_id = video_config['video_id']
                df = df.drop(df.query(f'gh_id == "{gh_id}" and strategy_key == "{key}" and sort == {position}').index)
                override_data['strategy_key'].append(key)
                override_data['gh_id'].append(gh_id)
                override_data['sort'].append(position)
                override_data['video_id'].append(video_id)
    n_records = len(override_data['strategy_key'])
    override_data['dt_version'] = [dt_version] * n_records
    override_data['score'] = [0.0] * n_records
    df_to_append = pd.DataFrame(override_data)
    df = pd.concat([df, df_to_append], ignore_index=True)
    return df


def rank_for_base_designate(run_dt, run_hour, stg_key):
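    """Designated base ranking: instead of computing scores, assign one fixed
    video per account -- accounts in TARGET_GH_IDS get video 13586800, all other
    accounts get 20463342."""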
    dt_version = f'{run_dt}{run_hour}'
    ranked_df = pd.DataFrame()  # start from an empty DataFrame

    for gh_id in GH_IDS:
        if gh_id in TARGET_GH_IDS:
            temp_df = pd.DataFrame({
                'strategy_key': [stg_key],
                'dt_version': [dt_version],
                'gh_id': [gh_id],
                'sort': [1],
                'video_id': [13586800],
                'score': [0.5]
            })
        else:
            temp_df = pd.DataFrame({
                'strategy_key': [stg_key],
                'dt_version': [dt_version],
                'gh_id': [gh_id],
                'sort': [1],
                'video_id': [20463342],
                'score': [0.5]
            })
        ranked_df = pd.concat([ranked_df, temp_df], ignore_index=True)
    return ranked_df


def build_and_transfer_data(run_dt, run_hour, project, **kwargs):
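    """Run all ranking strategies for (run_dt, run_hour), merge and apply manual
    overrides, join video title/cover from ODPS, then write the result to the
    ODPS rank table and sync it to MySQL. With dry_run=True, only print the
    result."""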
    dt_version = f'{run_dt}{run_hour}'
    dry_run = kwargs.get('dry_run', False)

    gh_df = get_and_update_gh_ids(run_dt)

    layer1_rank = rank_for_layer1(run_dt, run_hour, ODS_PROJECT, EXPLORE_POOL_TABLE, gh_df)
    layer2_rank = rank_for_layer2(run_dt, run_hour, ODS_PROJECT, GH_REPLY_STATS_TABLE, ODPS_RANK_RESULT_TABLE)
    # base_rank = rank_for_base(run_dt, run_hour, ODS_PROJECT, GH_REPLY_STATS_TABLE, ODPS_RANK_RESULT_TABLE,BASE_GROUP_NAME)
    # layer2_rank = rank_for_base_designate(run_dt, run_hour, EXPLORE2_GROUP_NAME)
    base_rank = rank_for_base_designate(run_dt, run_hour, BASE_GROUP_NAME)

    final_rank_df = pd.concat([layer1_rank, layer2_rank, base_rank]).reset_index(drop=True)

    final_rank_df = postprocess_override_by_config(final_rank_df, dt_version)
    check_result_data(final_rank_df)

    odps_instance = get_odps_instance(project)
    odps_ranked_df = odps.DataFrame(final_rank_df)

    video_df = get_dataframe_from_odps('videoods', 'wx_video')
    video_df['cover_url'] = video_df['cover_img_path'] + CDN_IMG_OPERATOR
    video_df = video_df['id', 'title', 'cover_url']
    final_df = odps_ranked_df.join(video_df, on=('video_id', 'id'))

    final_df = final_df.to_pandas()
    final_df = final_df[['strategy_key', 'dt_version', 'gh_id', 'sort', 'video_id', 'title', 'cover_url', 'score']]

    # reverse sending order
    final_df['sort'] = SEND_N + 1 - final_df['sort']

    if dry_run:
        print(final_df[['strategy_key', 'gh_id', 'sort', 'video_id', 'score', 'title']]
              .sort_values(by=['strategy_key', 'gh_id', 'sort']))
        return

    # save to ODPS
    t = odps_instance.get_table(ODPS_RANK_RESULT_TABLE)
    part_spec_dict = {'dt': run_dt, 'hour': run_hour, 'ctime': dt_version}
    part_spec = ','.join(['{}={}'.format(k, part_spec_dict[k]) for k in part_spec_dict.keys()])
    with t.open_writer(partition=part_spec, create_partition=True, overwrite=True) as writer:
        writer.write(list(final_df.itertuples(index=False)))

    # sync to MySQL
    data_to_insert = [tuple(row) for row in final_df.itertuples(index=False)]
    data_columns = list(final_df.columns)
    mysql = MysqlHelper(CONFIG.MYSQL_CRAWLER_INFO)
    mysql.batch_insert(RDS_RANK_RESULT_TABLE, data_to_insert, data_columns)


def main_loop():
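    """Entry point: once the upstream daily stats partition is ready, build and
    transfer the rank data; otherwise retry after 60s. On failure, send an alert
    to Feishu unless running in the development environment."""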
    argparser = ArgumentParser()
    argparser.add_argument('-n', '--dry-run', action='store_true')
    argparser.add_argument('--run-at', help='assume to run at date and hour, yyyyMMddHH')
    args = argparser.parse_args()

    run_date = datetime.today()
    if args.run_at:
        run_date = datetime.strptime(args.run_at, "%Y%m%d%H")
        LOGGER.info(f"Assume to run at {run_date.strftime('%Y-%m-%d %H:00')}")

    try:
        now_date = datetime.today()
        LOGGER.info(f"开始执行: {datetime.strftime(now_date, '%Y-%m-%d %H:%M')}")

        last_date = run_date - timedelta(1)
        last_dt = last_date.strftime("%Y%m%d")
        # check whether the daily-updated upstream data is ready
        # the upstream stats table is updated daily, but its fields are designed to also support hourly updates
        h_data_count = check_data_partition(ODS_PROJECT, GH_REPLY_STATS_TABLE, last_dt, '00')
        if h_data_count > 0:
            LOGGER.info('上游数据表查询数据条数={},开始计算'.format(h_data_count))
            run_dt = run_date.strftime("%Y%m%d")
            run_hour = run_date.strftime("%H")
            LOGGER.info(f'run_dt: {run_dt}, run_hour: {run_hour}')
            build_and_transfer_data(run_dt, run_hour, ODS_PROJECT,
                                    dry_run=args.dry_run)
            LOGGER.info('数据更新完成')
        else:
            LOGGER.info("上游数据未就绪,等待60s")
            Timer(60, main_loop).start()
        return
    except Exception as e:
        LOGGER.error(f"数据更新失败, exception: {e}, traceback: {traceback.format_exc()}")
        if CONFIG.ENV_TEXT == '开发环境':
            return
        send_msg_to_feishu(
            webhook=CONFIG.FEISHU_ROBOT['growth_task_robot'].get('webhook'),
            key_word=CONFIG.FEISHU_ROBOT['growth_task_robot'].get('key_word'),
            msg_text=f"rov-offline{CONFIG.ENV_TEXT} - 数据更新失败\n"
                     f"exception: {e}\n"
                     f"traceback: {traceback.format_exc()}"
        )


if __name__ == '__main__':
    LOGGER.info("%s 开始执行" % os.path.basename(__file__))
    LOGGER.info(f"environment: {CONFIG.ENV_TEXT}")
    main_loop()