ros_multi_class_model_predice_analyse.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import numpy as np
  2. from sklearn.metrics import roc_auc_score
  def parse_line(line):
      """Parse one tab-separated prediction line into per-class arrays.

      The line must contain at least three tab-separated fields:
      ``label<TAB>...<TAB>[s0,s1,...]``. Field 0 is the integer true label and
      field 2 is a bracketed, comma-separated list of per-class scores. Field 1
      is not used here.

      Args:
          line: one raw line from the prediction file.

      Returns:
          ``(aucs, accuracy_rate)``, both of shape (num_classes, 2):
          - aucs: float rows ``(is_true_label, score)``, one per class.
          - accuracy_rate: int rows ``(is_true_label, is_argmax)``, one per
            class; exactly one row has ``is_argmax == 1``.

      Raises:
          IndexError: the line has fewer than three tab-separated fields.
          ValueError: the label or a score is not numeric.
      """
      parts = line.strip().split("\t")
      label = int(parts[0])
      scores = np.array([float(x) for x in parts[2].strip("[]").split(",")])
      class_ids = np.arange(scores.size)
      # One-hot indicator of the true label (all zeros if label is out of range).
      is_true = (class_ids == label).astype(int)
      # One-hot indicator of the top-1 prediction; np.argmax picks the first index on ties.
      is_pred = (class_ids == np.argmax(scores)).astype(int)
      aucs = np.column_stack((is_true, scores))
      accuracy_rate = np.column_stack((is_true, is_pred))
      return aucs, accuracy_rate
  def compute_auc(auc_data):
      """Compute a one-vs-rest ROC AUC for every class.

      Args:
          auc_data: sequence of per-sample arrays of shape (num_classes, 2),
              with rows ``(is_true_label, score)`` as produced by ``parse_line``.

      Returns:
          A list with one AUC per class. ROC AUC is undefined when a class's
          indicator column holds a single value (the class is never, or always,
          the true label). Such classes get ``nan`` instead of aborting the
          whole report. An empty ``auc_data`` yields ``[]``.
      """
      if len(auc_data) == 0:
          return []
      # (num_samples, num_classes, 2) -> iterate per class as (num_samples, 2)
      stacked = np.stack(auc_data)
      auc_scores = []
      for class_cols in np.moveaxis(stacked, 1, 0):
          labels, scores = class_cols[:, 0], class_cols[:, 1]
          # roc_auc_score raises ValueError when y_true contains only one class.
          if np.unique(labels).size < 2:
              auc_scores.append(float("nan"))
          else:
              auc_scores.append(roc_auc_score(labels, scores))
      return auc_scores
  def compute_accuracy_rate(acc_data):
      """Compute overall top-1 accuracy and a per-class hit rate.

      Args:
          acc_data: sequence of per-sample int arrays of shape (num_classes, 2),
              with rows ``(is_true_label, is_argmax)`` as produced by
              ``parse_line``.

      Returns:
          ``(global_accuracy, per_label_accuracy)``:
          - global_accuracy: fraction of samples whose arg-max class equals
            the true label.
          - per_label_accuracy: for each class i, the fraction of samples
            predicted as i (arg-max == i) whose true label is also i. This is
            per-class precision. A class that is never predicted gets 0.
          Empty input returns ``(0.0, [])``.
      """
      if len(acc_data) == 0:
          return 0.0, []
      stacked = np.stack(acc_data)  # (num_samples, num_classes, 2)
      is_true = stacked[:, :, 0] == 1
      is_pred = stacked[:, :, 1] == 1
      hits = is_true & is_pred
      # Each sample has exactly one predicted class, so hits.sum() counts
      # correctly classified samples.
      global_accuracy = hits.sum() / stacked.shape[0]
      # NOTE(review): the denominator is the *predicted* count, so this is
      # precision. If per-true-label accuracy (recall) is intended, divide by
      # is_true.sum(axis=0) instead.
      predicted_cnt = is_pred.sum(axis=0).tolist()
      correct_cnt = hits.sum(axis=0).tolist()
      per_label_accuracy = [
          0 if n_pred == 0 else n_hit / n_pred
          for n_hit, n_pred in zip(correct_cnt, predicted_cnt)
      ]
      return global_accuracy, per_label_accuracy
  if __name__ == "__main__":
      import sys

      # Default local file path; pass a different path as the first CLI argument to override it.
      file_path = sys.argv[1] if len(sys.argv) > 1 else "/Users/zhao/Desktop/tzld/ros/ros_predict_20250302.txt"
      # Read and parse the data. Blank lines are skipped because parse_line
      # would fail on them (int("") / missing fields).
      with open(file_path, "r", encoding="utf-8") as f:
          parsed_data = [parse_line(line) for line in f if line.strip()]
      auc_data = [item[0] for item in parsed_data]
      acc_data = [item[1] for item in parsed_data]
      # Compute AUC
      auc_scores = compute_auc(auc_data)
      # Compute accuracy
      global_acc, per_label_acc = compute_accuracy_rate(acc_data)
      # Print results
      print("AUC Scores:")
      for i, auc in enumerate(auc_scores):
          print(f"Label {i}: AUC = {auc:.4f}")
      print(f"\nGlobal Accuracy: {global_acc:.4f}")
      print("\nPer Label Accuracy:")
      for i, acc in enumerate(per_label_acc):
          print(f"Label {i}: Accuracy = {acc:.4f}")