import sys

import numpy as np


# Input used when no path is passed on the command line.
DEFAULT_FILE_PATH = "/Users/zhao/Desktop/tzld/ros/ros_predict_20250302.txt"


def parse_line(line):
    """Parse one tab-separated prediction record.

    Field 0 is the integer index of the true class and field 2 is a
    bracketed, comma-separated list of per-class scores such as
    "[0.1,0.7,0.2]". Field 1 is not used by this script.

    Args:
        line: One raw line from the prediction file.

    Returns:
        A pair ``(aucs, accuracy_rate)`` of arrays with one row per class:

        * ``aucs`` (float, shape ``(C, 2)``): column 0 is 1.0 for the true
          class and 0.0 otherwise; column 1 is that class's score.
        * ``accuracy_rate`` (int, shape ``(C, 2)``): column 0 is 1 for the
          true class; column 1 is 1 for the highest-scoring (predicted) class.

    Raises:
        IndexError: if the line has fewer than three tab-separated fields.
        ValueError: if the label or a score cannot be parsed as a number.
    """
    parts = line.strip().split("\t")
    label = int(parts[0])
    scores = np.array([float(x) for x in parts[2].strip("[]").split(",")])

    class_ids = np.arange(len(scores))
    is_true = class_ids == label
    # np.argmax returns the first index on ties, so exactly one class is "predicted".
    is_pred = class_ids == np.argmax(scores)

    aucs = np.column_stack((is_true.astype(float), scores))
    accuracy_rate = np.column_stack((is_true.astype(int), is_pred.astype(int)))
    return aucs, accuracy_rate


def compute_auc(auc_data):
    """Compute a one-vs-rest ROC AUC for every class.

    Args:
        auc_data: Sequence of ``aucs`` arrays as returned by ``parse_line``,
            one per sample, all with the same number of classes C.

    Returns:
        A list of C AUC values in class order. A class whose label column
        holds only one value (it never occurs, or every sample has it) has an
        undefined AUC and gets ``nan`` instead of aborting the whole run.
        An empty input yields ``[]``.
    """
    if len(auc_data) == 0:
        return []

    # Imported here so the parsing/accuracy helpers stay usable without scikit-learn.
    from sklearn.metrics import roc_auc_score

    stacked = np.stack(auc_data)  # (N samples, C classes, 2)
    auc_scores = []
    # Transpose so each iteration yields one class's column across all samples.
    for labels, scores in zip(stacked[..., 0].T, stacked[..., 1].T):
        if np.unique(labels).size < 2:
            # roc_auc_score raises ValueError when only one label value is present.
            auc_scores.append(float("nan"))
            continue
        auc_scores.append(roc_auc_score(labels, scores))
    return auc_scores


def compute_accuracy_rate(acc_data):
    """Compute overall top-1 accuracy and a per-class hit rate.

    Args:
        acc_data: Sequence of ``accuracy_rate`` arrays from ``parse_line``,
            one per sample, all with the same number of classes C.

    Returns:
        ``(global_accuracy, per_label_accuracy)`` where

        * ``global_accuracy`` is the fraction of samples whose predicted
          (argmax) class equals the true class, or ``nan`` for empty input;
        * ``per_label_accuracy[i]`` is, among samples *predicted* as class i,
          the fraction whose true class is i (per-class precision), or 0.0
          when class i is never predicted.
    """
    if len(acc_data) == 0:
        return float("nan"), []

    stacked = np.stack(acc_data)  # (N samples, C classes, 2)
    is_true = stacked[..., 0] == 1
    is_pred = stacked[..., 1] == 1
    hits = is_true & is_pred

    # Each sample has exactly one predicted class, so hits counts correct samples.
    global_accuracy = np.count_nonzero(hits) / len(acc_data)

    # NOTE(review): this filters on the *predicted* column, i.e. precision per
    # class. If per-class recall was intended, divide by is_true.sum(axis=0).
    predicted_cnt = is_pred.sum(axis=0)
    correct_cnt = hits.sum(axis=0)
    per_label_accuracy = [
        float(correct / predicted) if predicted else 0.0
        for correct, predicted in zip(correct_cnt, predicted_cnt)
    ]
    return global_accuracy, per_label_accuracy


def main(argv=None):
    """Evaluate a prediction file and print AUC and accuracy figures.

    The input path is the first command-line argument, falling back to
    DEFAULT_FILE_PATH. Blank lines in the file are skipped.
    """
    argv = sys.argv[1:] if argv is None else argv
    file_path = argv[0] if argv else DEFAULT_FILE_PATH

    with open(file_path, "r", encoding="utf-8") as f:
        parsed_data = [parse_line(line) for line in f if line.strip()]

    auc_data = [aucs for aucs, _ in parsed_data]
    acc_data = [acc for _, acc in parsed_data]

    auc_scores = compute_auc(auc_data)
    global_acc, per_label_acc = compute_accuracy_rate(acc_data)

    print("AUC Scores:")
    for i, auc in enumerate(auc_scores):
        print(f"Label {i}: AUC = {auc:.4f}")

    print(f"\nGlobal Accuracy: {global_acc:.4f}")

    print("\nPer Label Accuracy:")
    for i, acc in enumerate(per_label_acc):
        print(f"Label {i}: Accuracy = {acc:.4f}")


if __name__ == "__main__":
    main()