|
@@ -0,0 +1,74 @@
|
|
|
+#!/usr/bin/env python
|
|
|
+# coding=utf-8
|
|
|
+
|
|
|
+import argparse
|
|
|
+from collections import defaultdict
|
|
|
+
|
|
|
+import pandas as pd
|
|
|
+
|
|
|
+
|
|
|
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is not positive."""
    if denominator > 0:
        return numerator / denominator
    return 0


def process(input_file):
    """Print a confusion matrix and precision/recall/accuracy for gender predictions.

    The CSV at ``input_file`` is read with pandas. Its ``gender`` column is
    treated as the reference label and its ``predict`` column as the predicted
    label. Only the three known label strings (male, female, unknown) are
    reported; rows carrying any other value still contribute to the per-value
    counts but never show up in the report.

    For every known label the function prints one confusion-matrix line per
    predicted label, followed by that label's precision and recall. It then
    prints the overall accuracy and the precision/recall aggregated over the
    two real genders (the "unknown" label is excluded from that aggregate).

    Args:
        input_file: Path (or anything ``pandas.read_csv`` accepts) of a CSV
            file containing ``gender`` and ``predict`` columns.

    Returns:
        None. All results are written to standard output.

    Raises:
        KeyError: If the ``gender`` or ``predict`` column is missing.
    """
    # Known label values: male, female, unknown (runtime strings, do not edit).
    labels = ['男性', '女性', '未知']
    unknown_label = labels[-1]

    frame = pd.read_csv(input_file)

    # Count occurrences per label, per prediction, and per (label, prediction)
    # pair. Zipping the two columns avoids building a Series for every row,
    # which is what DataFrame.iterrows() does.
    label_counts = defaultdict(int)
    predict_counts = defaultdict(int)
    pair_counts = defaultdict(int)
    for label, predict in zip(frame['gender'], frame['predict']):
        label_counts[label] += 1
        predict_counts[predict] += 1
        pair_counts[(label, predict)] += 1

    total_cnt = 0
    total_right_cnt = 0
    gender_label_cnt = 0
    gender_predict_cnt = 0
    gender_right_cnt = 0
    for label in labels:
        # Confusion-matrix row for this reference label.
        for predict in labels:
            print(label, predict, pair_counts.get((label, predict), 0))

        # .get() keeps the defaultdicts from growing on lookups.
        right_cnt = pair_counts.get((label, label), 0)
        label_cnt = label_counts.get(label, 0)
        predict_cnt = predict_counts.get(label, 0)
        precision_rate = _safe_ratio(right_cnt, predict_cnt)
        recall_rate = _safe_ratio(right_cnt, label_cnt)

        total_cnt += label_cnt
        total_right_cnt += right_cnt
        if label != unknown_label:
            gender_label_cnt += label_cnt
            gender_predict_cnt += predict_cnt
            gender_right_cnt += right_cnt

        print('%s precision(%d/%d) %.4f%%' % (label, right_cnt, predict_cnt, 100 * precision_rate))
        print('%s recall(%d/%d) %.4f%%' % (label, right_cnt, label_cnt, 100 * recall_rate))
        print('')

    # Guarded like every other ratio: an empty file, or one without any known
    # label, reports 0% accuracy instead of raising ZeroDivisionError.
    accuracy_rate = _safe_ratio(total_right_cnt, total_cnt)
    gender_precision_rate = _safe_ratio(gender_right_cnt, gender_predict_cnt)
    gender_recall_rate = _safe_ratio(gender_right_cnt, gender_label_cnt)
    print('%s accuracy(%d/%d) %.4f%%' % ('all', total_right_cnt, total_cnt, 100 * accuracy_rate))
    print('%s precision(%d/%d) %.4f%%' % ('gender', gender_right_cnt, gender_predict_cnt, 100 * gender_precision_rate))
    print('%s recall(%d/%d) %.4f%%' % ('gender', gender_right_cnt, gender_label_cnt, 100 * gender_recall_rate))
|
|
|
+
|
|
|
+
|
|
|
if __name__ == '__main__':
    # Script entry point: read the required --input_file option, echo the
    # parsed arguments framed by blank lines, then print the evaluation report.
    cli = argparse.ArgumentParser()
    cli.add_argument('--input_file', required=True, help='input files')
    options = cli.parse_args()

    spacer = '\n\n'
    print(spacer)
    print(options)
    print(spacer)

    process(options.input_file)
|