import argparse
import gzip
import sys

import pandas as pd
from hdfs import InsecureClient

# Module-level HDFS client shared by read_predict; it connects as user "spark".
# NOTE(review): the endpoint is hard-coded and looks environment-specific —
# confirm before running against another cluster.
client = InsecureClient("http://master-1-1.c-7f31a3eea195cb73.cn-hangzhou.emr.aliyuncs.com:9870", user="spark")


def read_predict(hdfs_path: str) -> list:
    """Load prediction records from every gzip file in an HDFS directory.

    Each decompressed line is expected to contain four tab-separated fields:
    the integer label, a field that is not used here, a bracketed
    comma-separated score list, and an identifier whose part before the first
    ``_`` is taken as the ``cid``. Lines with any other field count are
    skipped.

    Args:
        hdfs_path: HDFS directory to read. A trailing ``/`` is optional.

    Returns:
        A list of dicts with the keys ``"cid"`` (str), ``"label"`` (int) and
        ``"score"`` (float, the second entry of the score list).

    Raises:
        ValueError: If a four-field line has a non-numeric label or score.
        IndexError: If the score list of a line has fewer than two entries.
    """
    # Normalise once so callers may pass the directory with or without "/".
    base_path = hdfs_path.rstrip("/")

    result = []
    for file_name in client.list(hdfs_path):
        with client.read(base_path + "/" + file_name) as reader:
            with gzip.GzipFile(fileobj=reader, mode="rb") as gz_file:
                # Iterate the gzip stream line by line instead of decoding the
                # whole file at once, so memory use stays bounded per line.
                for raw_line in gz_file:
                    fields = raw_line.decode("utf-8").rstrip("\n").split("\t")
                    if len(fields) != 4:
                        continue
                    cid = fields[3].split("_")[0]
                    label = int(fields[0])
                    # The score list looks like "[a,b]"; keep the second entry.
                    score = float(fields[2].replace("[", "").replace("]", "").split(",")[1])

                    result.append({
                        "cid": cid,
                        "label": label,
                        "score": score
                    })

    return result


def _relative_error(frame: pd.DataFrame, score_column: str) -> pd.Series:
    """Return ``abs((score - true) / true)`` per row, using 0 where ``true`` is 0."""
    actual = frame["true"]
    error = abs((frame[score_column] - actual) / actual)
    # Rows with a zero actual rate have no meaningful ratio; report them as 0.
    return error.mask(actual == 0, 0)


def _build_report(model1_result: list, model2_result: list) -> pd.DataFrame:
    """Aggregate both models' records per ``cid`` and join them into one table."""
    old_stats = pd.DataFrame(model1_result).groupby("cid").agg(
        view=("cid", "size"),
        conv=("label", "sum"),
        old_score_avg=("score", lambda x: round(x.mean(), 6)),
    ).reset_index()
    # Observed rate per cid, taken from the first model's records.
    old_stats["true"] = old_stats["conv"] / old_stats["view"]

    new_stats = pd.DataFrame(model2_result).groupby("cid").agg(
        new_score_avg=("score", lambda x: round(x.mean(), 6)),
    ).reset_index()

    # Left join keeps every cid of the first model; a cid that is missing from
    # the second model ends up with new_score_avg == 0 after fillna.
    report = pd.merge(old_stats, new_stats, on="cid", how="left")
    report = report.fillna(0)

    report["abs((old-true)/true)"] = _relative_error(report, "old_score_avg")
    report["abs((new-true)/true)"] = _relative_error(report, "new_score_avg")

    report = report[["cid", "view", "conv", "true", "old_score_avg", "new_score_avg",
                     "abs((old-true)/true)", "abs((new-true)/true)"]]
    return report.sort_values(by=["view"], ascending=False)


def _main(model1_predict_path: str, model2_predict_path: str, file: str):
    """Compare two models' predictions per ``cid`` and write a text report.

    For every ``cid`` the report lists the number of records (``view``), the
    sum of labels (``conv``), their ratio (``true``), the mean score of each
    model rounded to 6 decimals, and each mean's relative deviation from
    ``true``. Rows are sorted by ``view`` in descending order.

    Args:
        model1_predict_path: HDFS directory with the old model's predictions.
        model2_predict_path: HDFS directory with the new model's predictions.
        file: Local path the report is written to (overwritten if present).

    Side effects:
        Writes the report file and prints ``"0"`` to stdout afterwards.
    """
    if not model1_predict_path.endswith("/"):
        model1_predict_path += "/"

    if not model2_predict_path.endswith("/"):
        model2_predict_path += "/"

    model1_result = read_predict(model1_predict_path)
    model2_result = read_predict(model2_predict_path)

    report = _build_report(model1_result, model2_result)

    # Explicit encoding keeps the output independent of the platform locale.
    with open(file, "w", encoding="utf-8") as writer:
        writer.write(report.to_string(index=False))
    print("0")


if __name__ == '__main__':
    # Command-line entry point: parse the two prediction directories and the
    # report path, validate them, then run the comparison.
    parser = argparse.ArgumentParser(description="model_predict_analyse.py")
    parser.add_argument("-p", "--predict_path_list", nargs='*',
                        help="模型评估结果保存路径,第一个为老模型评估结果,第二个为新模型评估结果")
    parser.add_argument("-f", "--file", help="最后计算结果的保存路径")
    args = parser.parse_args()

    # argparse leaves the attribute as None when -p is omitted; treat that
    # like an empty list so the count check below reports it cleanly.
    predict_path_list = args.predict_path_list or []
    # Validate the arguments before any HDFS work starts. sys.exit with a
    # string prints it to stderr and exits with status 1.
    if len(predict_path_list) != 2:
        sys.exit("expected exactly two paths for -p/--predict_path_list "
                 "(old model first, new model second), got %d" % len(predict_path_list))
    if not args.file:
        sys.exit("-f/--file is required: path the report is written to")
    _main(predict_path_list[0], predict_path_list[1], args.file)