import argparse
import gzip
import os.path
from typing import Tuple

import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from hdfs import InsecureClient

# WebHDFS endpoint and user for the shared HDFS client. Like the paths below,
# both can be overridden through environment variables; the defaults are the
# Aliyun EMR master node and the "spark" user.
HDFS_URL = os.environ.get("HDFS_URL", "http://master-1-1.c-7f31a3eea195cb73.cn-hangzhou.emr.aliyuncs.com:9870")
HDFS_USER = os.environ.get("HDFS_USER", "spark")
client = InsecureClient(HDFS_URL, user=HDFS_USER)

# NOTE(review): the two paths below are not referenced by the functions in this
# file as seen here — presumably consumed elsewhere; verify before removing.
SEGMENT_BASE_PATH = os.environ.get("SEGMENT_BASE_PATH", "/dw/recommend/model/36_model_attachment/score_calibration_file")
PREDICT_CACHE_PATH = os.environ.get("PREDICT_CACHE_PATH", "/root/zhaohp/XGB/predict_cache")


def parse_predict_line(line: str) -> Tuple[bool, dict]:
    """Parse one tab-separated prediction line.

    A valid line has exactly four tab-separated columns:
    ``label``, an ignored column, a bracketed list ``[p0,p1]`` and an id
    of the form ``cid_suffix``.

    :param line: raw text line; a trailing newline (LF or CRLF) is ignored.
    :return: ``(True, {"label": int, "cid": str, "score": float})`` for a
        four-column line, ``(False, {})`` for any other column count.
        ``score`` is the second entry of the bracketed list and ``cid`` is
        the part of the last column before the first ``_``.
    :raises ValueError: a four-column line has a non-integer label or a
        non-numeric score.
    :raises IndexError: the bracketed list has fewer than two entries.
    """
    # rstrip("\r\n") also removes the carriage return of CRLF input, which
    # would otherwise stick to the last column (and to cid when it has no "_").
    sp = line.rstrip("\r\n").split("\t")
    if len(sp) != 4:
        return False, {}
    label = int(sp[0])
    cid = sp[3].split("_")[0]
    # Column 3 looks like "[p0,p1]"; the score is the second entry.
    score = float(sp[2].replace("[", "").replace("]", "").split(",")[1])
    return True, {
        "label": label,
        "cid": cid,
        "score": score,
    }


def read_predict_file(file_path: str) -> pd.DataFrame:
    """Load prediction records from HDFS or from a local file.

    Paths starting with ``/dw`` are treated as HDFS directories: every entry
    listed directly under the directory is read as gzip-compressed UTF-8
    text. Any other path is opened as a single local UTF-8 text file.

    Every line goes through ``parse_predict_line``; lines it rejects are
    skipped.

    :param file_path: HDFS directory (``/dw...``) or local file path.
    :return: DataFrame with columns ``label``, ``cid`` and ``score``. The
        columns exist even when no line was parsed, so callers can select
        them without a ``KeyError``.
    """
    result = []

    def _collect(line: str) -> None:
        ok, record = parse_predict_line(line)
        if ok:
            result.append(record)

    if file_path.startswith("/dw"):
        if not file_path.endswith("/"):
            file_path += "/"
        for name in client.list(file_path):
            with client.read(file_path + name) as reader:
                with gzip.GzipFile(fileobj=reader, mode="rb") as gz_file:
                    # Stream line by line instead of decompressing the whole
                    # file into memory. Splitting on b"\n" before decoding is
                    # safe: in UTF-8 that byte never occurs inside a
                    # multi-byte character.
                    for raw_line in gz_file:
                        _collect(raw_line.decode("utf-8"))
    else:
        # Explicit UTF-8 (matching the HDFS branch) instead of the
        # locale-dependent default; iterate lazily rather than readlines().
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                _collect(line)
    return pd.DataFrame(result, columns=["label", "cid", "score"])


def _calibration_curve(df: pd.DataFrame, num_bins: int) -> pd.DataFrame:
    """Bucket ``df`` by score quantiles and return per-bucket means.

    :param df: frame with numeric ``score`` and ``label`` columns (left
        unmodified).
    :param num_bins: requested number of quantile buckets; duplicate bin
        edges are dropped, so fewer buckets may result.
    :return: one row per non-empty bucket with columns ``p_bin``,
        ``mean_p`` (mean score) and ``mean_y`` (mean label).
    """
    binned = df.assign(p_bin=pd.qcut(df['score'], q=num_bins, duplicates='drop'))
    # observed=True: aggregate only buckets that actually occur. This avoids
    # NaN rows for empty categories and pandas' FutureWarning about the
    # changing default for categorical groupers.
    return binned.groupby('p_bin', observed=True).agg(
        mean_p=('score', 'mean'),
        mean_y=('label', 'mean'),
    ).reset_index()


def _main(old_predict_path: str, new_predict_path: str, output_path: str,
          num_bins: int = 50, axis_max: float = 0.02):
    """Plot the calibration of an old and a new model on one figure.

    Each prediction set is split into score-quantile buckets. For every
    bucket, the mean predicted score is plotted against the mean label,
    together with the y = x reference line. The figure is saved to
    ``output_path`` at 300 dpi.

    :param old_predict_path: old model predictions (HDFS dir or local file,
        as accepted by ``read_predict_file``).
    :param new_predict_path: new model predictions, same format.
    :param output_path: destination image path.
    :param num_bins: requested number of quantile buckets per model.
    :param axis_max: upper limit of both axes.
    """
    old_curve = _calibration_curve(read_predict_file(old_predict_path), num_bins)
    new_curve = _calibration_curve(read_predict_file(new_predict_path), num_bins)

    fig = plt.figure(figsize=(6, 6))
    try:
        # marker='o' so that ms=3 actually takes effect; without a marker the
        # marker size is ignored.
        plt.plot(old_curve['mean_p'], old_curve['mean_y'], marker='o', ms=3, ls='-', color='blue', label='old')
        plt.plot(new_curve['mean_p'], new_curve['mean_y'], marker='o', ms=3, ls='-', color='red', label='new')
        plt.plot([0, 1], [0, 1], color='gray', linestyle='--', label='Ideal Line')
        plt.xlim(0, axis_max)
        plt.ylim(0, axis_max)
        plt.xlabel('Predicted pCTR')
        plt.ylabel('Actual CTR')
        plt.title('Q-Q Plot for pCTR Calibration')
        plt.legend()
        plt.grid(True)

        plt.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)

if __name__ == '__main__':
    # Command-line entry point: compare calibration of two prediction sets.
    cli = argparse.ArgumentParser(description=__file__)
    # Help text: "old model evaluation results".
    cli.add_argument("-op", "--old_predict_path", required=True, help="老模型评估结果")
    # Help text: "new model evaluation results".
    cli.add_argument("-np", "--new_predict_path", required=True, help="新模型评估结果")
    # Destination path of the rendered plot.
    cli.add_argument('--output', required=True)
    opts = cli.parse_args()

    _main(opts.old_predict_path, opts.new_predict_path, opts.output)