1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
import sys

import numpy as np
from sklearn.metrics import roc_auc_score
def parse_line(line):
    """Parse one tab-separated prediction record into per-class arrays.

    The first field is read as the integer true label and the third field as
    a bracketed, comma-separated list of per-class scores (for example
    ``[0.1,0.7,0.2]``). The second field is not used here.

    Args:
        line: One raw text line of the prediction file.

    Returns:
        A tuple ``(aucs, accuracy_rate)`` of 2-column arrays with one row per
        class. ``aucs[i]`` is ``(is_true_label, score)`` and
        ``accuracy_rate[i]`` is ``(is_true_label, is_top_score)``, where each
        flag is 1 or 0.

    Raises:
        IndexError: If the line has fewer than three tab-separated fields.
        ValueError: If the label or a score cannot be parsed as a number.
    """
    fields = line.strip().split("\t")
    label = int(fields[0])
    scores = np.array([float(tok) for tok in fields[2].strip("[]").split(",")])

    class_ids = np.arange(scores.size)
    # 1 for the row whose class index equals the true label, 0 elsewhere.
    is_label = (class_ids == label).astype(int)
    # argmax returns the first index on ties, so exactly one class is flagged.
    is_top = (class_ids == np.argmax(scores)).astype(int)

    # (is_true_label, score) pairs used for AUC.
    aucs = np.column_stack((is_label, scores))
    # (is_true_label, is_top_score) pairs used for accuracy.
    accuracy_rate = np.column_stack((is_label, is_top))
    return aucs, accuracy_rate
def compute_auc(auc_data):
    """Compute a one-vs-rest ROC AUC for every class with ``roc_auc_score``.

    Args:
        auc_data: Sequence with one entry per sample. Each entry is indexable
            by class index and yields an ``(is_true_label, score)`` pair. The
            number of classes is taken from the first entry.

    Returns:
        A list with one AUC value per class. A class whose labels are all
        identical (never or always the true label) gets ``nan``, because ROC
        AUC is undefined when only one class is present. An empty
        ``auc_data`` yields an empty list.
    """
    if len(auc_data) == 0:
        return []

    auc_scores = []
    for class_idx in range(len(auc_data[0])):
        # Gather the (label, score) pair of this class across all samples.
        pairs = np.array([sample[class_idx] for sample in auc_data])
        labels, scores = pairs[:, 0], pairs[:, 1]

        if np.unique(labels).size < 2:
            # With a single label value the ROC curve is undefined; report
            # nan for this class instead of aborting the whole evaluation.
            auc_scores.append(float("nan"))
            continue

        auc_scores.append(roc_auc_score(labels, scores))
    return auc_scores
def compute_accuracy_rate(acc_data):
    """Compute the global accuracy and a per-class accuracy figure.

    Args:
        acc_data: Sequence with one entry per sample. Each entry has one
            ``(is_true_label, is_top_score)`` pair of 0/1 flags per class.
            All entries must have the same number of classes.

    Returns:
        A tuple ``(global_accuracy, per_label_accuracy)``.
        ``global_accuracy`` is the fraction of samples whose top-scoring
        class is the true label. ``per_label_accuracy[i]`` is, among the
        samples whose top-scoring class is ``i``, the fraction whose true
        label is also ``i``; it is 0 when no sample was predicted as ``i``.
        An empty ``acc_data`` yields ``(0.0, [])``.

    Raises:
        ValueError: If the entries do not all share the same shape.
    """
    if len(acc_data) == 0:
        return 0.0, []

    # Shape (num_samples, num_classes, 2); the last axis holds the two flags.
    stacked = np.stack(acc_data)
    is_true = stacked[:, :, 0] == 1
    is_predicted = stacked[:, :, 1] == 1
    is_hit = is_true & is_predicted

    # A hit means the top-scoring class is the true label for that sample.
    global_accuracy = float(is_hit.sum()) / stacked.shape[0]

    # Per class: hits divided by how often that class was the prediction.
    hit_counts = is_hit.sum(axis=0).tolist()
    predicted_counts = is_predicted.sum(axis=0).tolist()
    per_label_accuracy = [
        hits / predicted if predicted else 0.0
        for hits, predicted in zip(hit_counts, predicted_counts)
    ]
    return global_accuracy, per_label_accuracy
def main(argv=None):
    """Evaluate a prediction file and print per-class AUC and accuracy.

    Args:
        argv: Optional list of command-line arguments without the program
            name. Defaults to ``sys.argv[1:]``. The first argument, when
            present, is the path of the prediction file; otherwise the
            built-in default path is used.
    """
    args = sys.argv[1:] if argv is None else argv
    default_path = "/Users/zhao/Desktop/tzld/ros/ros_predict_20250302.txt"  # local file path
    file_path = args[0] if args else default_path

    # Read and parse the data, streaming the file line by line.
    with open(file_path, "r", encoding="utf-8") as f:
        # Blank lines (such as a trailing newline at end of file) hold no
        # record and would fail label parsing, so they are skipped.
        parsed_data = [parse_line(line) for line in f if line.strip()]

    auc_data = [item[0] for item in parsed_data]
    acc_data = [item[1] for item in parsed_data]

    # Compute AUC.
    auc_scores = compute_auc(auc_data)
    # Compute accuracy.
    global_acc, per_label_acc = compute_accuracy_rate(acc_data)

    # Print the results.
    print("AUC Scores:")
    for i, auc in enumerate(auc_scores):
        print(f"Label {i}: AUC = {auc:.4f}")
    print(f"\nGlobal Accuracy: {global_acc:.4f}")
    print("\nPer Label Accuracy:")
    for i, acc in enumerate(per_label_acc):
        print(f"Label {i}: Accuracy = {acc:.4f}")


if __name__ == "__main__":
    main()
|