Переглянути джерело

feat:添加ros分析脚本

zhaohaipeng 4 місяців тому
батько
коміт
bf4962c65d
1 змінених файлів з 98 додано та 0 видалено
  1. 98 0
      model/ros_multi_class_model_predice_analyse.py

+ 98 - 0
model/ros_multi_class_model_predice_analyse.py

@@ -0,0 +1,98 @@
+import numpy as np
+
+
+def parse_line(line):
+    """ 解析每一行数据 """
+    parts = line.strip().split("\t")
+    label = int(parts[0])
+    scores = np.array([float(x) for x in parts[2].strip("[]").split(",")])
+
+    # 找到最大值索引
+    max_index = np.argmax(scores)
+
+    # 生成 (label, score) 形式的 aucs
+    aucs = np.array([(1 if i == label else 0, scores[i]) for i in range(len(scores))])
+
+    # 生成 (是否为真实 label, 是否为最大值) 的 accuracyRate
+    accuracy_rate = np.array([(1 if i == label else 0, 1 if i == max_index else 0) for i in range(len(scores))])
+
+    return aucs, accuracy_rate
+
+
+def compute_auc(auc_data):
+    """ 计算 AUC """
+    num_classes = len(auc_data[0])  # 8 classes
+    auc_scores = []
+
+    for i in range(num_classes):
+        col_data = np.array([row[i] for row in auc_data])  # 取第 i 列
+        labels, scores = col_data[:, 0], col_data[:, 1]
+
+        # 按 scores 降序排序
+        sorted_indices = np.argsort(-scores)
+        sorted_labels = labels[sorted_indices]
+
+        # 计算正负样本数
+        pos = np.sum(sorted_labels == 1)
+        neg = len(sorted_labels) - pos
+        if pos == 0 or neg == 0:
+            auc_scores.append(0.5)
+            continue
+
+        # 计算 AUC
+        rank_sum = np.sum(np.where(sorted_labels == 1)[0] + 1)  # 计算正样本的秩次和
+        auc = (rank_sum - pos * (pos + 1) / 2) / (pos * neg)
+        auc_scores.append(auc)
+
+    return auc_scores
+
+
+def compute_accuracy_rate(acc_data):
+    """ 计算 accuracy """
+    num_classes = len(acc_data[0])  # 8 classes
+
+    # 全局 accuracy 计算
+    acc_flatten = np.vstack(acc_data)
+    global_correct = np.sum((acc_flatten[:, 0] == 1) & (acc_flatten[:, 1] == 1))
+    total_count = acc_flatten.shape[0]
+    global_accuracy = global_correct / total_count
+
+    # 按 label 计算 accuracy
+    per_label_accuracy = []
+    for i in range(num_classes):
+        col_data = np.array([row[i] for row in acc_data])  # 取第 i 列
+        correct = np.sum((col_data[:, 0] == 1) & (col_data[:, 1] == 1))
+        total_positive = np.sum(col_data[:, 0] == 1)
+        per_label_accuracy.append(0 if total_positive == 0 else correct / total_positive)
+
+    return global_accuracy, per_label_accuracy
+
+
+if __name__ == "__main__":
+    file_path = "/Users/zhao/Desktop/tzld/ros/ros_predict_20250302.txt"  # 本地文件路径
+
+    # 读取数据
+    with open(file_path, "r") as f:
+        data_lines = f.readlines()
+
+    # 解析数据
+    parsed_data = [parse_line(line) for line in data_lines]
+    auc_data = [item[0] for item in parsed_data]
+    acc_data = [item[1] for item in parsed_data]
+
+    # 计算 AUC
+    auc_scores = compute_auc(auc_data)
+
+    # 计算 Accuracy
+    global_acc, per_label_acc = compute_accuracy_rate(acc_data)
+
+    # 打印结果
+    print("AUC Scores:")
+    for i, auc in enumerate(auc_scores):
+        print(f"Label {i}: AUC = {auc:.4f}")
+
+    print(f"\nGlobal Accuracy: {global_acc:.4f}")
+
+    print("\nPer Label Accuracy:")
+    for i, acc in enumerate(per_label_acc):
+        print(f"Label {i}: Accuracy = {acc:.4f}")