import sys

import numpy as np

# Default location of the prediction dump; can be overridden on the command line.
DEFAULT_PREDICTION_FILE = "/Users/zhao/Desktop/tzld/ros/ros_predict_20250302.txt"


def parse_line(line):
    """Parse one tab-separated prediction record.

    The first column is the integer true label and the third column is a
    bracketed, comma-separated list of per-class scores, e.g. ``[0.1,0.9]``.
    The second column is ignored.

    Returns:
        A pair ``(aucs, accuracy_rate)`` of arrays with one row per class:
        ``aucs[i] = (is_true_label, score_i)`` as floats and
        ``accuracy_rate[i] = (is_true_label, is_argmax)`` as integers.

    Raises:
        IndexError: if the line has fewer than three columns.
        ValueError: if the label or a score cannot be parsed as a number.
    """
    parts = line.strip().split("\t")
    label = int(parts[0])
    scores = np.array([float(x) for x in parts[2].strip("[]").split(",")])

    class_ids = np.arange(len(scores))
    # Comparison (not indexing) so an out-of-range label simply matches no class.
    is_true = (class_ids == label).astype(int)
    is_top = (class_ids == np.argmax(scores)).astype(int)

    aucs = np.column_stack((is_true.astype(float), scores))
    accuracy_rate = np.column_stack((is_true, is_top))
    return aucs, accuracy_rate


def _average_ranks(values):
    """Return 1-based ascending ranks of *values*; ties share their mean rank."""
    n = len(values)
    if n == 0:
        return np.empty(0, dtype=float)
    order = np.argsort(values, kind="mergesort")
    ordered = values[order]
    # A new tie group starts wherever the sorted value changes.
    starts_group = np.concatenate(([True], ordered[1:] != ordered[:-1]))
    group_of = np.cumsum(starts_group) - 1
    first = np.flatnonzero(starts_group)
    past_last = np.append(first[1:], n)
    # Ranks inside a group run first+1 .. past_last; their mean is below.
    mean_rank = (first + past_last + 1) / 2.0
    ranks = np.empty(n, dtype=float)
    ranks[order] = mean_rank[group_of]
    return ranks


def compute_auc(auc_data):
    """Compute the one-vs-rest AUC of every class.

    Args:
        auc_data: sequence of per-sample arrays as returned by ``parse_line``
            (first element), each holding ``(is_true_label, score)`` per class.

    Returns:
        A list with one AUC per class. A class that has no positive or no
        negative samples gets 0.5. An empty input yields an empty list.
    """
    if len(auc_data) == 0:
        return []

    num_classes = len(auc_data[0])
    auc_scores = []
    for i in range(num_classes):
        col_data = np.array([row[i] for row in auc_data])  # column of class i
        labels, scores = col_data[:, 0], col_data[:, 1]

        is_pos = labels == 1
        pos = int(np.sum(is_pos))
        neg = len(labels) - pos
        if pos == 0 or neg == 0:
            # AUC is undefined with a single class present; report chance level.
            auc_scores.append(0.5)
            continue

        # Mann-Whitney U statistic: ranks must be ASCENDING in score (lowest
        # score has rank 1) and tied scores must share their average rank,
        # otherwise the result is 1 - AUC or depends on the tie order.
        rank_sum = _average_ranks(scores)[is_pos].sum()
        auc = (rank_sum - pos * (pos + 1) / 2) / (pos * neg)
        auc_scores.append(float(auc))
    return auc_scores


def compute_accuracy_rate(acc_data):
    """Compute the overall accuracy and the per-label hit rate.

    Args:
        acc_data: sequence of per-sample arrays as returned by ``parse_line``
            (second element), each holding ``(is_true_label, is_argmax)`` per
            class.

    Returns:
        ``(global_accuracy, per_label_accuracy)``. ``global_accuracy`` is the
        fraction of samples whose top-scoring class is the true label.
        ``per_label_accuracy[i]`` is the fraction of samples with true label
        ``i`` that were predicted as ``i`` (0.0 when the label never occurs).
        An empty input yields ``(0.0, [])``.
    """
    num_samples = len(acc_data)
    if num_samples == 0:
        return 0.0, []

    stacked = np.asarray(acc_data)  # (samples, classes, 2)
    is_true = stacked[:, :, 0] == 1
    hits = is_true & (stacked[:, :, 1] == 1)

    # Each sample contributes at most one hit, so the denominator is the
    # number of samples, not samples * classes.
    global_accuracy = float(hits.sum() / num_samples)

    per_label_accuracy = []
    for correct, total_positive in zip(hits.sum(axis=0), is_true.sum(axis=0)):
        if total_positive == 0:
            per_label_accuracy.append(0.0)
        else:
            per_label_accuracy.append(float(correct / total_positive))
    return global_accuracy, per_label_accuracy


def main(file_path=DEFAULT_PREDICTION_FILE):
    """Read the prediction file and print per-class AUC and accuracy."""
    # Stream the file and skip blank lines, which carry no record.
    with open(file_path, "r", encoding="utf-8") as f:
        parsed_data = [parse_line(line) for line in f if line.strip()]

    auc_data = [item[0] for item in parsed_data]
    acc_data = [item[1] for item in parsed_data]

    auc_scores = compute_auc(auc_data)
    global_acc, per_label_acc = compute_accuracy_rate(acc_data)

    print("AUC Scores:")
    for i, auc in enumerate(auc_scores):
        print(f"Label {i}: AUC = {auc:.4f}")

    print(f"\nGlobal Accuracy: {global_acc:.4f}")
    print("\nPer Label Accuracy:")
    for i, acc in enumerate(per_label_acc):
        print(f"Label {i}: Accuracy = {acc:.4f}")


if __name__ == "__main__":
    # Optional first argument overrides the default input path.
    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_PREDICTION_FILE)