|
@@ -0,0 +1,98 @@
|
|
|
+import numpy as np
|
|
|
+
|
|
|
+
|
|
|
+def parse_line(line):
|
|
|
+ """ 解析每一行数据 """
|
|
|
+ parts = line.strip().split("\t")
|
|
|
+ label = int(parts[0])
|
|
|
+ scores = np.array([float(x) for x in parts[2].strip("[]").split(",")])
|
|
|
+
|
|
|
+ # 找到最大值索引
|
|
|
+ max_index = np.argmax(scores)
|
|
|
+
|
|
|
+ # 生成 (label, score) 形式的 aucs
|
|
|
+ aucs = np.array([(1 if i == label else 0, scores[i]) for i in range(len(scores))])
|
|
|
+
|
|
|
+ # 生成 (是否为真实 label, 是否为最大值) 的 accuracyRate
|
|
|
+ accuracy_rate = np.array([(1 if i == label else 0, 1 if i == max_index else 0) for i in range(len(scores))])
|
|
|
+
|
|
|
+ return aucs, accuracy_rate
|
|
|
+
|
|
|
+
|
|
|
+def compute_auc(auc_data):
|
|
|
+ """ 计算 AUC """
|
|
|
+ num_classes = len(auc_data[0]) # 8 classes
|
|
|
+ auc_scores = []
|
|
|
+
|
|
|
+ for i in range(num_classes):
|
|
|
+ col_data = np.array([row[i] for row in auc_data]) # 取第 i 列
|
|
|
+ labels, scores = col_data[:, 0], col_data[:, 1]
|
|
|
+
|
|
|
+ # 按 scores 降序排序
|
|
|
+ sorted_indices = np.argsort(-scores)
|
|
|
+ sorted_labels = labels[sorted_indices]
|
|
|
+
|
|
|
+ # 计算正负样本数
|
|
|
+ pos = np.sum(sorted_labels == 1)
|
|
|
+ neg = len(sorted_labels) - pos
|
|
|
+ if pos == 0 or neg == 0:
|
|
|
+ auc_scores.append(0.5)
|
|
|
+ continue
|
|
|
+
|
|
|
+ # 计算 AUC
|
|
|
+ rank_sum = np.sum(np.where(sorted_labels == 1)[0] + 1) # 计算正样本的秩次和
|
|
|
+ auc = (rank_sum - pos * (pos + 1) / 2) / (pos * neg)
|
|
|
+ auc_scores.append(auc)
|
|
|
+
|
|
|
+ return auc_scores
|
|
|
+
|
|
|
+
|
|
|
+def compute_accuracy_rate(acc_data):
|
|
|
+ """ 计算 accuracy """
|
|
|
+ num_classes = len(acc_data[0]) # 8 classes
|
|
|
+
|
|
|
+ # 全局 accuracy 计算
|
|
|
+ acc_flatten = np.vstack(acc_data)
|
|
|
+ global_correct = np.sum((acc_flatten[:, 0] == 1) & (acc_flatten[:, 1] == 1))
|
|
|
+ total_count = acc_flatten.shape[0]
|
|
|
+ global_accuracy = global_correct / total_count
|
|
|
+
|
|
|
+ # 按 label 计算 accuracy
|
|
|
+ per_label_accuracy = []
|
|
|
+ for i in range(num_classes):
|
|
|
+ col_data = np.array([row[i] for row in acc_data]) # 取第 i 列
|
|
|
+ correct = np.sum((col_data[:, 0] == 1) & (col_data[:, 1] == 1))
|
|
|
+ total_positive = np.sum(col_data[:, 0] == 1)
|
|
|
+ per_label_accuracy.append(0 if total_positive == 0 else correct / total_positive)
|
|
|
+
|
|
|
+ return global_accuracy, per_label_accuracy
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ file_path = "/Users/zhao/Desktop/tzld/ros/ros_predict_20250302.txt" # 本地文件路径
|
|
|
+
|
|
|
+ # 读取数据
|
|
|
+ with open(file_path, "r") as f:
|
|
|
+ data_lines = f.readlines()
|
|
|
+
|
|
|
+ # 解析数据
|
|
|
+ parsed_data = [parse_line(line) for line in data_lines]
|
|
|
+ auc_data = [item[0] for item in parsed_data]
|
|
|
+ acc_data = [item[1] for item in parsed_data]
|
|
|
+
|
|
|
+ # 计算 AUC
|
|
|
+ auc_scores = compute_auc(auc_data)
|
|
|
+
|
|
|
+ # 计算 Accuracy
|
|
|
+ global_acc, per_label_acc = compute_accuracy_rate(acc_data)
|
|
|
+
|
|
|
+ # 打印结果
|
|
|
+ print("AUC Scores:")
|
|
|
+ for i, auc in enumerate(auc_scores):
|
|
|
+ print(f"Label {i}: AUC = {auc:.4f}")
|
|
|
+
|
|
|
+ print(f"\nGlobal Accuracy: {global_acc:.4f}")
|
|
|
+
|
|
|
+ print("\nPer Label Accuracy:")
|
|
|
+ for i, acc in enumerate(per_label_acc):
|
|
|
+ print(f"Label {i}: Accuracy = {acc:.4f}")
|