Jelajahi Sumber

feat:添加评估结果分析脚本

zhaohaipeng 6 bulan lalu
induk
melakukan
a54d3fe9cc
1 mengubah file dengan 38 tambahan dan 9 penghapusan
  1. 38 9
      ad/model_predict_analyse.py

+ 38 - 9
ad/model_predict_analyse.py

@@ -1,17 +1,48 @@
-import argparse
-import sys
 import gzip
+import pandas as pd
 
-from pyspark.sql import SparkSession
+from hdfs import InsecureClient
 
+client = InsecureClient("http://master-1-1.c-7f31a3eea195cb73.cn-hangzhou.emr.aliyuncs.com:9870", user="spark")
 
-def read_predict(hdfs_path: str):
-    df = spark.read.text(hdfs_path)
-    df.show(truncate=False)
+
+def read_predict(hdfs_path: str) -> list:
+    result = []
+    for file in client.list(hdfs_path):
+        with client.read(hdfs_path + file) as reader:
+            with gzip.GzipFile(fileobj=reader, mode="rb") as gz_file:
+                for line in gz_file.read().decode("utf-8").split("\n"):
+                    split = line.split("\t")
+                    if len(split) != 4:
+                        continue
+                    cid = split[3].split("_")[0]
+                    label = split[0]
+                    score = split[2].replace("[", "").replace("]", "").split(",")[1]
+
+                    result.append({
+                        "cid": cid,
+                        "label": label,
+                        "score": score
+                    })
+
+    return result
 
 
 def _main():
-    read_predict("/dw/recommend/model/34_ad_predict_data/20241004_351_0927_1003_1000/*")
+    model1_result = read_predict("/dw/recommend/model/34_ad_predict_data/20241004_351_0927_1003_1000/")
+    model2_result = read_predict("/dw/recommend/model/34_ad_predict_data/20241004_351_0927_1003_1000/")
+
+    m1 = pd.DataFrame(model1_result)
+    g1 = m1.groupby("cid").agg(count=('cid', 'size'), average_value=('score', 'mean'))
+    # 获取出现次数最多的十个 cid
+    most_common_cid1 = g1.nlargest(10, 'count')
+    print(most_common_cid1)
+
+    m2 = pd.DataFrame(model2_result)
+    g2 = m2.groupby("cid").agg(count=('cid', 'size'), average_value=('score', 'mean'))
+    # 获取出现次数最多的十个 cid
+    most_common_cid2 = g2.nlargest(10, 'count')
+    print(most_common_cid2)
 
 
 if __name__ == '__main__':
@@ -23,6 +54,4 @@ if __name__ == '__main__':
     # # 判断参数是否正常
     # if len(predict_path_list) != 2:
     #     sys.exit(1)
-    spark = SparkSession.builder.appName("model_predict_analyse").getOrCreate()
     _main()
-    spark.stop()