extract_pose.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. #!/usr/bin/env python3
  2. """
  3. 提取人体姿态骨骼图 - 使用MediaPipe Pose
  4. 输出:每张图片的骨骼关键点图(PNG)+ 关键点坐标(JSON)
  5. """
  6. import mediapipe as mp
  7. import cv2
  8. import numpy as np
  9. import json
  10. import os
  11. from PIL import Image
  12. mp_pose = mp.solutions.pose
  13. mp_drawing = mp.solutions.drawing_utils
  14. mp_drawing_styles = mp.solutions.drawing_styles
  15. def extract_pose(image_path, output_dir, img_id):
  16. """提取单张图片的姿态骨骼"""
  17. img = cv2.imread(image_path)
  18. if img is None:
  19. print(f"无法读取图片: {image_path}")
  20. return None
  21. h, w = img.shape[:2]
  22. img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  23. with mp_pose.Pose(
  24. static_image_mode=True,
  25. model_complexity=2,
  26. enable_segmentation=False,
  27. min_detection_confidence=0.5
  28. ) as pose:
  29. results = pose.process(img_rgb)
  30. if not results.pose_landmarks:
  31. print(f" 未检测到姿态: {img_id}")
  32. return None
  33. # 创建黑色背景的骨骼图
  34. skeleton_img = np.zeros((h, w, 3), dtype=np.uint8)
  35. # 绘制骨骼连接线(白色)
  36. mp_drawing.draw_landmarks(
  37. skeleton_img,
  38. results.pose_landmarks,
  39. mp_pose.POSE_CONNECTIONS,
  40. landmark_drawing_spec=mp_drawing.DrawingSpec(
  41. color=(255, 255, 255), thickness=3, circle_radius=5
  42. ),
  43. connection_drawing_spec=mp_drawing.DrawingSpec(
  44. color=(200, 200, 200), thickness=2
  45. )
  46. )
  47. # 保存骨骼图
  48. skeleton_path = os.path.join(output_dir, f"{img_id}_pose_skeleton.png")
  49. cv2.imwrite(skeleton_path, skeleton_img)
  50. # 提取关键点坐标
  51. landmarks_data = {}
  52. landmark_names = [lm.name for lm in mp_pose.PoseLandmark]
  53. for i, landmark in enumerate(results.pose_landmarks.landmark):
  54. name = landmark_names[i] if i < len(landmark_names) else f"landmark_{i}"
  55. landmarks_data[name] = {
  56. "x": round(landmark.x, 4), # 归一化坐标 [0,1]
  57. "y": round(landmark.y, 4),
  58. "z": round(landmark.z, 4),
  59. "visibility": round(landmark.visibility, 4),
  60. "pixel_x": int(landmark.x * w),
  61. "pixel_y": int(landmark.y * h)
  62. }
  63. # 保存关键点JSON
  64. json_path = os.path.join(output_dir, f"{img_id}_pose_keypoints.json")
  65. with open(json_path, 'w', encoding='utf-8') as f:
  66. json.dump({
  67. "image_id": img_id,
  68. "image_size": {"width": w, "height": h},
  69. "landmarks": landmarks_data,
  70. "skeleton_image": f"{img_id}_pose_skeleton.png"
  71. }, f, ensure_ascii=False, indent=2)
  72. print(f" ✓ {img_id}: 骨骼图已保存 -> {skeleton_path}")
  73. return landmarks_data
  74. def main():
  75. input_dir = "input"
  76. output_dir = "output/features/pose_skeleton"
  77. results_summary = []
  78. for i in range(1, 10):
  79. img_id = f"img_{i}"
  80. image_path = os.path.join(input_dir, f"{img_id}.jpg")
  81. if not os.path.exists(image_path):
  82. print(f"图片不存在: {image_path}")
  83. continue
  84. print(f"处理 {img_id}...")
  85. landmarks = extract_pose(image_path, output_dir, img_id)
  86. if landmarks:
  87. results_summary.append({
  88. "image_id": img_id,
  89. "detected": True,
  90. "keypoints_file": f"{img_id}_pose_keypoints.json",
  91. "skeleton_file": f"{img_id}_pose_skeleton.png"
  92. })
  93. else:
  94. results_summary.append({
  95. "image_id": img_id,
  96. "detected": False
  97. })
  98. # 保存mapping.json
  99. mapping = {
  100. "dimension": "pose_skeleton",
  101. "description": "人体姿态骨骼关键点图,使用MediaPipe Pose提取33个关键点",
  102. "tool": "MediaPipe Pose v0.10.9",
  103. "format": {
  104. "skeleton_image": "PNG,黑色背景,白色骨骼连线",
  105. "keypoints_json": "JSON,包含33个关键点的归一化坐标和像素坐标"
  106. },
  107. "mappings": []
  108. }
  109. # 根据制作表结构建立对应关系
  110. pose_segment_map = {
  111. "img_1": [
  112. {"segment": "段落1.1", "category": "实质", "feature": "女性人物姿态", "element": "元素1"},
  113. ],
  114. "img_2": [
  115. {"segment": "段落2.1", "category": "实质", "feature": "女性人物姿态", "element": "元素1"},
  116. ],
  117. "img_3": [
  118. {"segment": "段落3.1", "category": "实质", "feature": "女性人物姿态(跪姿)", "element": "元素1"},
  119. ],
  120. "img_4": [
  121. {"segment": "段落4.1", "category": "实质", "feature": "女性人物姿态(侧身)", "element": "元素1"},
  122. ],
  123. "img_5": [
  124. {"segment": "段落5.1", "category": "实质", "feature": "女性人物姿态(手臂特写)", "element": "元素1"},
  125. ],
  126. "img_6": [
  127. {"segment": "段落6.1", "category": "实质", "feature": "女性人物姿态(背部特写)", "element": "元素1"},
  128. ],
  129. "img_7": [
  130. {"segment": "段落7.1", "category": "实质", "feature": "女性人物姿态(侧颜/嗅花)", "element": "元素1"},
  131. ],
  132. "img_8": [
  133. {"segment": "段落8.1", "category": "实质", "feature": "女性人物姿态(侧身)", "element": "元素1"},
  134. ],
  135. "img_9": [
  136. {"segment": "段落9.1", "category": "实质", "feature": "女性人物姿态(背影远景)", "element": "元素1"},
  137. ],
  138. }
  139. for result in results_summary:
  140. img_id = result["image_id"]
  141. if result["detected"]:
  142. segments = pose_segment_map.get(img_id, [])
  143. for seg in segments:
  144. mapping["mappings"].append({
  145. "file": result["skeleton_file"],
  146. "keypoints_file": result["keypoints_file"],
  147. "source_image": f"input/{img_id}.jpg",
  148. "segment": seg["segment"],
  149. "category": seg["category"],
  150. "feature": seg["feature"],
  151. "element_id": seg["element"]
  152. })
  153. mapping_path = os.path.join(output_dir, "mapping.json")
  154. with open(mapping_path, 'w', encoding='utf-8') as f:
  155. json.dump(mapping, f, ensure_ascii=False, indent=2)
  156. print(f"\n✓ mapping.json 已保存: {mapping_path}")
  157. print(f"✓ 处理完成: {len([r for r in results_summary if r['detected']])} / {len(results_summary)} 张图片检测到姿态")
  158. if __name__ == "__main__":
  159. main()