1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
import sys

import numpy as np
def parse_line(line):
    """Parse one tab-separated prediction record into per-class evaluation pairs.

    The record is split on tabs: column 0 is read as the integer ground-truth
    label and column 2 as a bracketed, comma-separated list of per-class
    scores (for example ``[0.1, 0.7, 0.2]``). Column 1 is not used here.

    Args:
        line: A single raw line of the prediction file.

    Returns:
        A tuple ``(aucs, accuracy_rate)`` of 2-D arrays with one row per class:
        ``aucs[i]`` is ``(is_true_label, score)`` and ``accuracy_rate[i]`` is
        ``(is_true_label, is_top_score)``, with both flags encoded as 1/0.

    Raises:
        IndexError: If the line has fewer than three tab-separated columns.
        ValueError: If the label or any score cannot be parsed as a number.
    """
    parts = line.strip().split("\t")
    label = int(parts[0])
    scores = np.array([float(x) for x in parts[2].strip("[]").split(",")])

    class_ids = np.arange(scores.size)
    # One-hot flag for the true class; a label outside the class range simply
    # produces an all-zero column.
    is_label = (class_ids == label).astype(int)
    # One-hot flag for the predicted class; argmax picks the first index on ties.
    is_top = (class_ids == np.argmax(scores)).astype(int)

    # (label flag, score) pairs feed the AUC computation.
    aucs = np.column_stack((is_label, scores))
    # (label flag, top-score flag) pairs feed the accuracy computation.
    accuracy_rate = np.column_stack((is_label, is_top))
    return aucs, accuracy_rate
def _average_ranks(values):
    """Return 1-based ascending ranks of *values*, averaging ranks over ties."""
    _, inverse, counts = np.unique(values, return_inverse=True, return_counts=True)
    # For each distinct value, cumsum gives the last rank its group occupies;
    # stepping back by half the group width yields the group's mean rank.
    last_rank = np.cumsum(counts)
    mean_rank = last_rank - (counts - 1) / 2.0
    return mean_rank[inverse]


def compute_auc(auc_data):
    """Compute a one-vs-rest ROC AUC for every class.

    Args:
        auc_data: Sequence with one entry per sample; entry ``[i]`` is the
            ``(is_true_label, score)`` pair for class ``i``. The number of
            classes is taken from the first sample.

    Returns:
        A list with one float AUC per class. A class that has no positive or
        no negative samples gets 0.5. Empty input yields an empty list.
    """
    if len(auc_data) == 0:
        return []

    num_classes = len(auc_data[0])
    auc_scores = []
    for i in range(num_classes):
        # Gather the i-th (label, score) pair from every sample.
        col_data = np.array([row[i] for row in auc_data], dtype=float)
        is_pos = col_data[:, 0] == 1
        scores = col_data[:, 1]

        pos = int(np.sum(is_pos))
        neg = is_pos.size - pos
        if pos == 0 or neg == 0:
            # AUC is undefined with a single class present; report chance level.
            auc_scores.append(0.5)
            continue

        # Mann-Whitney U formulation: ranks must be ASCENDING in score so that
        # a higher-scored positive contributes a larger rank. Ranking in
        # descending order here would produce 1 - AUC. Ties share their
        # average rank so they count as half a correct ordering.
        rank_sum = _average_ranks(scores)[is_pos].sum()
        auc = (rank_sum - pos * (pos + 1) / 2) / (pos * neg)
        auc_scores.append(float(auc))
    return auc_scores
def compute_accuracy_rate(acc_data):
    """Compute overall accuracy and per-label accuracy of top-1 predictions.

    Args:
        acc_data: Sequence with one entry per sample; entry ``[i]`` is the
            ``(is_true_label, is_top_score)`` pair for class ``i``. The number
            of classes is taken from the first sample.

    Returns:
        A tuple ``(global_accuracy, per_label_accuracy)``.
        ``global_accuracy`` is the fraction of samples whose top-scoring class
        is the true label. ``per_label_accuracy[i]`` is the fraction of samples
        whose true label is ``i`` that were predicted as ``i`` (0.0 when the
        label never occurs). Empty input yields ``(0.0, [])``.
    """
    num_samples = len(acc_data)
    if num_samples == 0:
        return 0.0, []

    num_classes = len(acc_data[0])

    # Overall accuracy: each sample contributes at most one row where both
    # flags are set, so the hit count is the number of correct samples. The
    # denominator is the number of samples, not samples * classes (the row
    # count of the flattened matrix), otherwise a perfect model would score
    # 1 / num_classes.
    flat = np.vstack(acc_data)
    hits = (flat[:, 0] == 1) & (flat[:, 1] == 1)
    global_accuracy = float(np.sum(hits)) / num_samples

    # Per-label accuracy: restrict to samples whose true label is class i.
    per_label_accuracy = []
    for i in range(num_classes):
        col_data = np.array([row[i] for row in acc_data])
        is_label = col_data[:, 0] == 1
        total_positive = int(np.sum(is_label))
        correct = int(np.sum(is_label & (col_data[:, 1] == 1)))
        per_label_accuracy.append(
            correct / total_positive if total_positive else 0.0
        )
    return global_accuracy, per_label_accuracy
if __name__ == "__main__":
    # Local prediction file used by default; pass a different path as the
    # first command-line argument to evaluate another file.
    default_path = "/Users/zhao/Desktop/tzld/ros/ros_predict_20250302.txt"
    file_path = sys.argv[1] if len(sys.argv) > 1 else default_path

    # Stream the file and parse each record, skipping blank lines (a blank
    # line would otherwise fail when the label is converted to int).
    with open(file_path, "r", encoding="utf-8") as f:
        parsed_data = [parse_line(line) for line in f if line.strip()]

    auc_data = [item[0] for item in parsed_data]
    acc_data = [item[1] for item in parsed_data]

    # Per-class AUC.
    auc_scores = compute_auc(auc_data)

    # Overall and per-label accuracy.
    global_acc, per_label_acc = compute_accuracy_rate(acc_data)

    # Report.
    print("AUC Scores:")
    for i, auc in enumerate(auc_scores):
        print(f"Label {i}: AUC = {auc:.4f}")

    print(f"\nGlobal Accuracy: {global_acc:.4f}")

    print("\nPer Label Accuracy:")
    for i, acc in enumerate(per_label_acc):
        print(f"Label {i}: Accuracy = {acc:.4f}")
|