draw_predict_distribution.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import argparse
  2. import gzip
  3. import os.path
  4. import pandas as pd
  5. import numpy as np
  6. import matplotlib
  7. matplotlib.use('Agg')
  8. import matplotlib.pyplot as plt
  9. from hdfs import InsecureClient
  10. client = InsecureClient("http://master-1-1.c-7f31a3eea195cb73.cn-hangzhou.emr.aliyuncs.com:9870", user="spark")
  11. SEGMENT_BASE_PATH = os.environ.get("SEGMENT_BASE_PATH", "/dw/recommend/model/36_model_attachment/score_calibration_file")
  12. PREDICT_CACHE_PATH = os.environ.get("PREDICT_CACHE_PATH", "/root/zhaohp/XGB/predict_cache")
  13. def parse_predict_line(line: str) -> [bool, dict]:
  14. sp = line.replace("\n", "").split("\t")
  15. if len(sp) == 4:
  16. label = int(sp[0])
  17. cid = sp[3].split("_")[0]
  18. score = float(sp[2].replace("[", "").replace("]", "").split(",")[1])
  19. return True, {
  20. "label": label,
  21. "cid": cid,
  22. "score": score
  23. }
  24. return False, {}
  25. def read_predict_file(file_path: str) -> pd.DataFrame:
  26. result = []
  27. if file_path.startswith("/dw"):
  28. if not file_path.endswith("/"):
  29. file_path += "/"
  30. for file in client.list(file_path):
  31. with client.read(file_path + file) as reader:
  32. with gzip.GzipFile(fileobj=reader, mode="rb") as gz_file:
  33. for line in gz_file.read().decode("utf-8").split("\n"):
  34. b, d = parse_predict_line(line)
  35. if b: result.append(d)
  36. else:
  37. with open(file_path, "r") as f:
  38. for line in f.readlines():
  39. b, d = parse_predict_line(line)
  40. if b: result.append(d)
  41. return pd.DataFrame(result)
  42. def _main(old_predict_path: str, new_predict_path: str, output_path: str):
  43. old_df = read_predict_file(old_predict_path)
  44. new_df = read_predict_file(new_predict_path)
  45. num_bins = 50
  46. old_df['p_bin'], _ = pd.qcut(old_df['score'], q=num_bins, duplicates='drop', retbins=True)
  47. new_df['p_bin'], _ = pd.qcut(new_df['score'], q=num_bins, duplicates='drop', retbins=True)
  48. quantile_data_old = old_df.groupby('p_bin').agg(
  49. mean_p=('score', 'mean'),
  50. mean_y=('label', 'mean')
  51. ).reset_index()
  52. quantile_data_new = new_df.groupby('p_bin').agg(
  53. mean_p=('score', 'mean'),
  54. mean_y=('label', 'mean')
  55. ).reset_index()
  56. predicted_quantiles_old = quantile_data_old['mean_p']
  57. actual_quantiles_old = quantile_data_old['mean_y']
  58. predicted_quantiles_new = quantile_data_new['mean_p']
  59. actual_quantiles_new = quantile_data_new['mean_y']
  60. plt.figure(figsize=(6, 6))
  61. plt.plot(predicted_quantiles_old, actual_quantiles_old, ms=3, ls='-', color='blue', label='old')
  62. plt.plot(predicted_quantiles_new, actual_quantiles_new, ms=3, ls='-', color='red', label='new')
  63. plt.plot([0, 1], [0, 1], color='gray', linestyle='--', label='Ideal Line')
  64. plt.xlim(0, 0.02)
  65. plt.ylim(0, 0.02)
  66. plt.xlabel('Predicted pCTR')
  67. plt.ylabel('Actual CTR')
  68. plt.title('Q-Q Plot for pCTR Calibration')
  69. plt.legend()
  70. plt.grid(True)
  71. plt.savefig(output_path, dpi=300, bbox_inches='tight')
  72. plt.close()
  73. if __name__ == '__main__':
  74. parser = argparse.ArgumentParser(description=__file__)
  75. parser.add_argument("-op", "--old_predict_path", required=True, help="老模型评估结果")
  76. parser.add_argument("-np", "--new_predict_path", required=True, help="新模型评估结果")
  77. parser.add_argument('--output', required=True)
  78. args = parser.parse_args()
  79. _main(
  80. old_predict_path=args.old_predict_path,
  81. new_predict_path=args.new_predict_path,
  82. output_path=args.output
  83. )