# extract_pose_v2.py (7.4 KB)
#!/usr/bin/env python3
"""
Extract human pose skeleton images, v2 - using MediaPipe Pose.

Different detection parameters are used for different input images.
"""
import mediapipe as mp
import cv2
import numpy as np
import json
import os

# Shortcuts to the legacy MediaPipe "solutions" API: the Pose model and the
# drawing helpers used to render landmarks/connections onto an image.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils


  13. def extract_pose(image_path, output_dir, img_id, complexity=2, conf=0.5):
  14. """提取单张图片的姿态骨骼"""
  15. img = cv2.imread(image_path)
  16. if img is None:
  17. return None
  18. h, w = img.shape[:2]
  19. img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  20. with mp_pose.Pose(
  21. static_image_mode=True,
  22. model_complexity=complexity,
  23. enable_segmentation=False,
  24. min_detection_confidence=conf
  25. ) as pose:
  26. results = pose.process(img_rgb)
  27. if not results.pose_landmarks:
  28. return None
  29. # 创建黑色背景的骨骼图
  30. skeleton_img = np.zeros((h, w, 3), dtype=np.uint8)
  31. # 绘制骨骼
  32. mp_drawing.draw_landmarks(
  33. skeleton_img,
  34. results.pose_landmarks,
  35. mp_pose.POSE_CONNECTIONS,
  36. landmark_drawing_spec=mp_drawing.DrawingSpec(
  37. color=(255, 255, 255), thickness=4, circle_radius=6
  38. ),
  39. connection_drawing_spec=mp_drawing.DrawingSpec(
  40. color=(180, 180, 180), thickness=3
  41. )
  42. )
  43. # 保存骨骼图
  44. skeleton_path = os.path.join(output_dir, f"{img_id}_pose_skeleton.png")
  45. cv2.imwrite(skeleton_path, skeleton_img)
  46. # 提取关键点坐标
  47. landmark_names = [lm.name for lm in mp_pose.PoseLandmark]
  48. landmarks_data = {}
  49. for i, landmark in enumerate(results.pose_landmarks.landmark):
  50. name = landmark_names[i] if i < len(landmark_names) else f"landmark_{i}"
  51. landmarks_data[name] = {
  52. "x": round(landmark.x, 4),
  53. "y": round(landmark.y, 4),
  54. "z": round(landmark.z, 4),
  55. "visibility": round(landmark.visibility, 4),
  56. "pixel_x": int(landmark.x * w),
  57. "pixel_y": int(landmark.y * h)
  58. }
  59. # 保存关键点JSON
  60. json_path = os.path.join(output_dir, f"{img_id}_pose_keypoints.json")
  61. with open(json_path, 'w', encoding='utf-8') as f:
  62. json.dump({
  63. "image_id": img_id,
  64. "image_size": {"width": w, "height": h},
  65. "model_complexity": complexity,
  66. "detection_confidence": conf,
  67. "landmarks": landmarks_data,
  68. "skeleton_image": f"{img_id}_pose_skeleton.png"
  69. }, f, ensure_ascii=False, indent=2)
  70. return landmarks_data
  71. def main():
  72. input_dir = "input"
  73. output_dir = "output/features/pose_skeleton"
  74. # 针对不同图片的参数配置
  75. # img_5是手臂特写,无法检测全身姿态,记录为局部
  76. configs = {
  77. "img_1": {"complexity": 2, "conf": 0.5},
  78. "img_2": {"complexity": 0, "conf": 0.1},
  79. "img_3": {"complexity": 2, "conf": 0.5},
  80. "img_4": {"complexity": 2, "conf": 0.5},
  81. "img_5": None, # 手臂特写,无法检测全身
  82. "img_6": {"complexity": 0, "conf": 0.1},
  83. "img_7": {"complexity": 0, "conf": 0.1},
  84. "img_8": {"complexity": 2, "conf": 0.5},
  85. "img_9": {"complexity": 0, "conf": 0.1},
  86. }
  87. results_summary = []
  88. for i in range(1, 10):
  89. img_id = f"img_{i}"
  90. image_path = os.path.join(input_dir, f"{img_id}.jpg")
  91. if not os.path.exists(image_path):
  92. continue
  93. config = configs.get(img_id)
  94. print(f"处理 {img_id}...", end=" ")
  95. if config is None:
  96. print("跳过(局部特写,无法检测全身姿态)")
  97. results_summary.append({"image_id": img_id, "detected": False, "reason": "partial_view"})
  98. continue
  99. landmarks = extract_pose(image_path, output_dir, img_id,
  100. config["complexity"], config["conf"])
  101. if landmarks:
  102. print(f"✓ 骨骼图已保存")
  103. results_summary.append({
  104. "image_id": img_id,
  105. "detected": True,
  106. "keypoints_file": f"{img_id}_pose_keypoints.json",
  107. "skeleton_file": f"{img_id}_pose_skeleton.png"
  108. })
  109. else:
  110. print("✗ 未检测到姿态")
  111. results_summary.append({"image_id": img_id, "detected": False, "reason": "not_detected"})
  112. # 建立mapping.json
  113. pose_segment_map = {
  114. "img_1": {"segment": "段落1.1", "feature": "女性人物姿态(站立侧身作画)", "element": "元素1"},
  115. "img_2": {"segment": "段落2.1", "feature": "女性人物姿态(背对镜头作画)", "element": "元素1"},
  116. "img_3": {"segment": "段落3.1", "feature": "女性人物姿态(跪姿作画)", "element": "元素1"},
  117. "img_4": {"segment": "段落4.1", "feature": "女性人物姿态(侧身面对画架)", "element": "元素1"},
  118. "img_6": {"segment": "段落6.1", "feature": "女性人物姿态(背部特写作画)", "element": "元素1"},
  119. "img_7": {"segment": "段落7.1", "feature": "女性人物姿态(侧颜嗅花)", "element": "元素1"},
  120. "img_8": {"segment": "段落8.1", "feature": "女性人物姿态(侧身面对画架)", "element": "元素1"},
  121. "img_9": {"segment": "段落9.1", "feature": "女性人物姿态(背影远景)", "element": "元素1"},
  122. }
  123. mapping = {
  124. "dimension": "pose_skeleton",
  125. "description": "人体姿态骨骼关键点图,使用MediaPipe Pose提取33个关键点",
  126. "tool": "MediaPipe Pose v0.10.9",
  127. "format": {
  128. "skeleton_image": "PNG,黑色背景,白色骨骼连线,尺寸与原图相同",
  129. "keypoints_json": "JSON,包含33个关键点的归一化坐标(x,y,z)和像素坐标"
  130. },
  131. "mappings": []
  132. }
  133. for result in results_summary:
  134. img_id = result["image_id"]
  135. if result["detected"]:
  136. seg_info = pose_segment_map.get(img_id, {})
  137. mapping["mappings"].append({
  138. "file": result["skeleton_file"],
  139. "keypoints_file": result["keypoints_file"],
  140. "source_image": f"input/{img_id}.jpg",
  141. "segment": seg_info.get("segment", ""),
  142. "category": "实质",
  143. "feature": seg_info.get("feature", ""),
  144. "element_id": seg_info.get("element", "元素1"),
  145. "highlight_cluster": "cluster_1"
  146. })
  147. else:
  148. mapping["mappings"].append({
  149. "file": None,
  150. "source_image": f"input/{img_id}.jpg",
  151. "segment": pose_segment_map.get(img_id, {}).get("segment", ""),
  152. "category": "实质",
  153. "feature": "局部特写,无法提取全身姿态",
  154. "element_id": "元素1",
  155. "note": result.get("reason", "")
  156. })
  157. mapping_path = os.path.join(output_dir, "mapping.json")
  158. with open(mapping_path, 'w', encoding='utf-8') as f:
  159. json.dump(mapping, f, ensure_ascii=False, indent=2)
  160. detected = len([r for r in results_summary if r['detected']])
  161. print(f"\n✓ 完成: {detected}/{len(results_summary)} 张图片检测到姿态")
  162. print(f"✓ mapping.json 已保存")
  163. if __name__ == "__main__":
  164. main()