ros_multi_class_model_predice_analyse.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import numpy as np
  2. def parse_line(line):
  3. """ 解析每一行数据 """
  4. parts = line.strip().split("\t")
  5. label = int(parts[0])
  6. scores = np.array([float(x) for x in parts[2].strip("[]").split(",")])
  7. # 找到最大值索引
  8. max_index = np.argmax(scores)
  9. # 生成 (label, score) 形式的 aucs
  10. aucs = np.array([(1 if i == label else 0, scores[i]) for i in range(len(scores))])
  11. # 生成 (是否为真实 label, 是否为最大值) 的 accuracyRate
  12. accuracy_rate = np.array([(1 if i == label else 0, 1 if i == max_index else 0) for i in range(len(scores))])
  13. return aucs, accuracy_rate
  14. def compute_auc(auc_data):
  15. """ 计算 AUC """
  16. num_classes = len(auc_data[0]) # 8 classes
  17. auc_scores = []
  18. for i in range(num_classes):
  19. col_data = np.array([row[i] for row in auc_data]) # 取第 i 列
  20. labels, scores = col_data[:, 0], col_data[:, 1]
  21. # 按 scores 降序排序
  22. sorted_indices = np.argsort(-scores)
  23. sorted_labels = labels[sorted_indices]
  24. # 计算正负样本数
  25. pos = np.sum(sorted_labels == 1)
  26. neg = len(sorted_labels) - pos
  27. if pos == 0 or neg == 0:
  28. auc_scores.append(0.5)
  29. continue
  30. # 计算 AUC
  31. rank_sum = np.sum(np.where(sorted_labels == 1)[0] + 1) # 计算正样本的秩次和
  32. auc = (rank_sum - pos * (pos + 1) / 2) / (pos * neg)
  33. auc_scores.append(auc)
  34. return auc_scores
  35. def compute_accuracy_rate(acc_data):
  36. """ 计算 accuracy """
  37. num_classes = len(acc_data[0]) # 8 classes
  38. # 全局 accuracy 计算
  39. acc_flatten = np.vstack(acc_data)
  40. global_correct = np.sum((acc_flatten[:, 0] == 1) & (acc_flatten[:, 1] == 1))
  41. total_count = acc_flatten.shape[0]
  42. global_accuracy = global_correct / total_count
  43. # 按 label 计算 accuracy
  44. per_label_accuracy = []
  45. for i in range(num_classes):
  46. col_data = np.array([row[i] for row in acc_data]) # 取第 i 列
  47. correct = np.sum((col_data[:, 0] == 1) & (col_data[:, 1] == 1))
  48. total_positive = np.sum(col_data[:, 0] == 1)
  49. per_label_accuracy.append(0 if total_positive == 0 else correct / total_positive)
  50. return global_accuracy, per_label_accuracy
  51. if __name__ == "__main__":
  52. file_path = "/Users/zhao/Desktop/tzld/ros/ros_predict_20250302.txt" # 本地文件路径
  53. # 读取数据
  54. with open(file_path, "r") as f:
  55. data_lines = f.readlines()
  56. # 解析数据
  57. parsed_data = [parse_line(line) for line in data_lines]
  58. auc_data = [item[0] for item in parsed_data]
  59. acc_data = [item[1] for item in parsed_data]
  60. # 计算 AUC
  61. auc_scores = compute_auc(auc_data)
  62. # 计算 Accuracy
  63. global_acc, per_label_acc = compute_accuracy_rate(acc_data)
  64. # 打印结果
  65. print("AUC Scores:")
  66. for i, auc in enumerate(auc_scores):
  67. print(f"Label {i}: AUC = {auc:.4f}")
  68. print(f"\nGlobal Accuracy: {global_acc:.4f}")
  69. print("\nPer Label Accuracy:")
  70. for i, acc in enumerate(per_label_acc):
  71. print(f"Label {i}: Accuracy = {acc:.4f}")