|
@@ -1,4 +1,5 @@
|
|
|
import numpy as np
|
|
|
+from sklearn.metrics import roc_auc_score
|
|
|
|
|
|
|
|
|
def parse_line(line):
|
|
@@ -20,7 +21,7 @@ def parse_line(line):
|
|
|
|
|
|
|
|
|
def compute_auc(auc_data):
|
|
|
- """ 计算 AUC """
|
|
|
+ """ 计算 AUC 使用 roc_auc_score """
|
|
|
num_classes = len(auc_data[0]) # 8 classes
|
|
|
auc_scores = []
|
|
|
|
|
@@ -28,20 +29,8 @@ def compute_auc(auc_data):
|
|
|
col_data = np.array([row[i] for row in auc_data]) # 取第 i 列
|
|
|
labels, scores = col_data[:, 0], col_data[:, 1]
|
|
|
|
|
|
- # 按 scores 降序排序
|
|
|
- sorted_indices = np.argsort(-scores)
|
|
|
- sorted_labels = labels[sorted_indices]
|
|
|
-
|
|
|
- # 计算正负样本数
|
|
|
- pos = np.sum(sorted_labels == 1)
|
|
|
- neg = len(sorted_labels) - pos
|
|
|
- if pos == 0 or neg == 0:
|
|
|
- auc_scores.append(0.5)
|
|
|
- continue
|
|
|
-
|
|
|
# 计算 AUC
|
|
|
- rank_sum = np.sum(np.where(sorted_labels == 1)[0] + 1) # 计算正样本的秩次和
|
|
|
- auc = (rank_sum - pos * (pos + 1) / 2) / (pos * neg)
|
|
|
+ auc = roc_auc_score(labels, scores)
|
|
|
auc_scores.append(auc)
|
|
|
|
|
|
return auc_scores
|
|
@@ -54,16 +43,24 @@ def compute_accuracy_rate(acc_data):
|
|
|
# 全局 accuracy 计算
|
|
|
acc_flatten = np.vstack(acc_data)
|
|
|
global_correct = np.sum((acc_flatten[:, 0] == 1) & (acc_flatten[:, 1] == 1))
|
|
|
- total_count = acc_flatten.shape[0]
|
|
|
+ total_count = acc_flatten.shape[0] / num_classes
|
|
|
global_accuracy = global_correct / total_count
|
|
|
|
|
|
# 按 label 计算 accuracy
|
|
|
per_label_accuracy = []
|
|
|
for i in range(num_classes):
|
|
|
col_data = np.array([row[i] for row in acc_data]) # 取第 i 列
|
|
|
- correct = np.sum((col_data[:, 0] == 1) & (col_data[:, 1] == 1))
|
|
|
- total_positive = np.sum(col_data[:, 0] == 1)
|
|
|
- per_label_accuracy.append(0 if total_positive == 0 else correct / total_positive)
|
|
|
+
|
|
|
+ # 过滤这个分类的数据
|
|
|
+ class_all_data = col_data[col_data[:, 1] == 1]
|
|
|
+ # 过滤这个分类中预估对的数据
|
|
|
+ positive_data = class_all_data[class_all_data[:, 0] == 1]
|
|
|
+
|
|
|
+ class_cnt = class_all_data.shape[0]
|
|
|
+ positive_cnt = positive_data.shape[0]
|
|
|
+
|
|
|
+ accuracy = 0 if class_cnt == 0 else positive_cnt / class_cnt
|
|
|
+ per_label_accuracy.append(accuracy)
|
|
|
|
|
|
return global_accuracy, per_label_accuracy
|
|
|
|