jch 1 week ago
parent
commit
1a5dfb41a0
2 changed files with 78 additions and 1 deletions
  1. 4 1
      README.md
  2. 74 0
      src/preprocess/eval_result.py

+ 4 - 1
README.md

@@ -34,4 +34,7 @@
 - API_PORT=8000 CUDA_VISIBLE_DEVICES=0 llamafactory-cli api examples/inference/qwen2_5vl.yaml
 
 ## 1.10 api推理
-- python src/preprocess/qw_api_url_inference.py --input data/user_info_format_test.csv --output_file eval.csv
+- python src/preprocess/qw_api_url_inference.py --input data/user_info_format_test.csv --output_file test_result.csv
+
+## 1.11 评估
+- python src/preprocess/eval_result.py --input_file test_result.csv

+ 74 - 0
src/preprocess/eval_result.py

@@ -0,0 +1,74 @@
+#!/usr/bin/env python
+# coding=utf-8
+
+import argparse
+from collections import defaultdict
+
+import pandas as pd
+
+
+def process(input_file):
+    labels = ['男性', '女性', '未知']
+    detail_dict = dict()
+    label_dict = defaultdict(int)
+    predict_dict = defaultdict(int)
+    for index, row in pd.read_csv(input_file).iterrows():
+        label = row['gender']
+        predict = row['predict']
+        label_dict[label] += 1
+        predict_dict[predict] += 1
+        if label not in detail_dict:
+            detail_dict[label] = defaultdict(int)
+        detail_dict[label][predict] += 1
+
+    total_cnt = 0
+    total_right_cnt = 0
+    gender_label_cnt = 0
+    gender_predict_cnt = 0
+    gender_right_cnt = 0
+    for label in labels:
+        right_cnt = 0
+        one_dict = detail_dict.get(label, dict())
+        for predict in labels:
+            p_cnt = one_dict.get(predict, 0)
+            if label == predict:
+                right_cnt = p_cnt
+            print(label, predict, p_cnt)
+        label_cnt = label_dict[label]
+        predict_cnt = predict_dict[label]
+        precision_rate = 0
+        recall_rate = 0
+        if predict_cnt > 0:
+            precision_rate = right_cnt / predict_cnt
+        if label_cnt > 0:
+            recall_rate = right_cnt / label_cnt
+        total_cnt += label_cnt
+        total_right_cnt += right_cnt
+        if label != '未知':
+            gender_label_cnt += label_cnt
+            gender_predict_cnt += predict_cnt
+            gender_right_cnt += right_cnt
+        print('%s precision(%d/%d) %.4f%%' % (label, right_cnt, predict_cnt, 100 * precision_rate))
+        print('%s recall(%d/%d) %.4f%%' % (label, right_cnt, label_cnt, 100 * recall_rate))
+        print('')
+
+    gender_precision_rate = 0
+    gender_recall_rate = 0
+    if gender_predict_cnt > 0:
+        gender_precision_rate = gender_right_cnt / gender_predict_cnt
+    if gender_label_cnt > 0:
+        gender_recall_rate = gender_right_cnt / gender_label_cnt
+    print('%s accuracy(%d/%d) %.4f%%' % ('all', total_right_cnt, total_cnt, 100 * total_right_cnt / total_cnt))
+    print('%s precision(%d/%d) %.4f%%' % ('gender', gender_right_cnt, gender_predict_cnt, 100 * gender_precision_rate))
+    print('%s recall(%d/%d) %.4f%%' % ('gender', gender_right_cnt, gender_label_cnt, 100 * gender_recall_rate))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--input_file', required=True, help='input files')
+    args = parser.parse_args()
+    print('\n\n')
+    print(args)
+    print('\n\n')
+
+    process(args.input_file)