瀏覽代碼

Add draw_predict_distribution.py

StrayWarrior 4 月之前
父節點
當前提交
85db1adb2c
共有 1 個文件被更改,包括 99 次插入0 次删除
  1. 99 0
      ad/draw_predict_distribution.py

+ 99 - 0
ad/draw_predict_distribution.py

@@ -0,0 +1,99 @@
+import argparse
+import gzip
+import os.path
+
+import pandas as pd
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from hdfs import InsecureClient
+
+client = InsecureClient("http://master-1-1.c-7f31a3eea195cb73.cn-hangzhou.emr.aliyuncs.com:9870", user="spark")
+
+SEGMENT_BASE_PATH = os.environ.get("SEGMENT_BASE_PATH", "/dw/recommend/model/36_model_attachment/score_calibration_file")
+PREDICT_CACHE_PATH = os.environ.get("PREDICT_CACHE_PATH", "/root/zhaohp/XGB/predict_cache")
+
+
+def parse_predict_line(line: str) -> [bool, dict]:
+    sp = line.replace("\n", "").split("\t")
+    if len(sp) == 4:
+        label = int(sp[0])
+        cid = sp[3].split("_")[0]
+        score = float(sp[2].replace("[", "").replace("]", "").split(",")[1])
+        return True, {
+            "label": label,
+            "cid": cid,
+            "score": score
+        }
+    return False, {}
+
+
+def read_predict_file(file_path: str) -> pd.DataFrame:
+    result = []
+    if file_path.startswith("/dw"):
+        if not file_path.endswith("/"):
+            file_path += "/"
+        for file in client.list(file_path):
+            with client.read(file_path + file) as reader:
+                with gzip.GzipFile(fileobj=reader, mode="rb") as gz_file:
+                    for line in gz_file.read().decode("utf-8").split("\n"):
+                        b, d = parse_predict_line(line)
+                        if b: result.append(d)
+    else:
+        with open(file_path, "r") as f:
+            for line in f.readlines():
+                b, d = parse_predict_line(line)
+                if b: result.append(d)
+    return pd.DataFrame(result)
+
+
+def _main(old_predict_path: str, new_predict_path: str, output_path: str):
+    old_df = read_predict_file(old_predict_path)
+    new_df = read_predict_file(new_predict_path)
+
+    num_bins = 20
+    old_df['p_bin'], _ = pd.qcut(old_df['score'], q=num_bins, duplicates='drop', retbins=True)
+    new_df['p_bin'], _ = pd.qcut(new_df['score'], q=num_bins, duplicates='drop', retbins=True)
+
+    quantile_data_old = old_df.groupby('p_bin').agg(
+        mean_p=('score', 'mean'),
+        mean_y=('label', 'mean')
+    ).reset_index()
+    quantile_data_new = new_df.groupby('p_bin').agg(
+        mean_p=('score', 'mean'),
+        mean_y=('label', 'mean')
+    ).reset_index()
+
+    predicted_quantiles_old = quantile_data_old['mean_p']
+    actual_quantiles_old = quantile_data_old['mean_y']
+    predicted_quantiles_new = quantile_data_new['mean_p']
+    actual_quantiles_new = quantile_data_new['mean_y']
+
+    plt.figure(figsize=(6, 6))
+    plt.plot(predicted_quantiles_old, actual_quantiles_old, ms=3, ls='-', color='blue', label='old')
+    plt.plot(predicted_quantiles_new, actual_quantiles_new, ms=3, ls='-', color='red', label='new')
+    plt.plot([0, 1], [0, 1], color='gray', linestyle='--', label='Ideal Line')
+    plt.xlim(0, 0.02)
+    plt.ylim(0, 0.02)
+    plt.xlabel('Predicted pCTR')
+    plt.ylabel('Actual CTR')
+    plt.title('Q-Q Plot for pCTR Calibration')
+    plt.legend()
+    plt.grid(True)
+
+    plt.savefig(output_path, dpi=300, bbox_inches='tight')
+    plt.close()
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description=__file__)
+    parser.add_argument("-op", "--old_predict_path", required=True, help="老模型评估结果")
+    parser.add_argument("-np", "--new_predict_path", required=True, help="新模型评估结果")
+    parser.add_argument('--output', required=True)
+    args = parser.parse_args()
+
+    _main(
+        old_predict_path=args.old_predict_path,
+        new_predict_path=args.new_predict_path,
+        output_path=args.output
+    )