# extract_features.py
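"""
Multimodal feature extraction script.

For the "girl in a white dress painting outdoors" image set, extracts the
following feature dimensions:
1. openpose_skeleton - human pose skeleton (OpenPose)
2. depth_map - depth map (MiDaS)
3. lineart_edge - line art / edge map (Lineart)
4. color_palette - color palette (ColorThief + custom)
5. bokeh_mask - depth-of-field blur mask (derived from the depth map)
6. semantic_segmentation - semantic segmentation (based on color clustering)
7. color_distribution - color distribution vector (HSV histograms)
"""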
  12. import os
  13. import json
  14. import warnings
  15. warnings.filterwarnings('ignore')
  16. import numpy as np
  17. from PIL import Image
  18. import cv2
  19. from colorthief import ColorThief
  20. # 设置工作目录
  21. WORKDIR = os.path.dirname(os.path.abspath(__file__))
  22. INPUT_DIR = os.path.join(WORKDIR, 'input')
  23. OUTPUT_DIR = os.path.join(WORKDIR, 'output', 'features')
  24. print("=== 开始多模态特征提取 ===")
  25. print(f"工作目录: {WORKDIR}")
  26. # ============================================================
  27. # 维度1: OpenPose 人体姿态骨架
  28. # ============================================================
  29. def extract_openpose(img_path, output_path):
  30. """提取人体姿态骨架图"""
  31. from controlnet_aux import OpenposeDetector
  32. detector = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
  33. img = Image.open(img_path)
  34. result = detector(img, hand_and_face=True)
  35. result.save(output_path)
  36. print(f" OpenPose saved: {output_path}")
  37. return True
  38. # ============================================================
  39. # 维度2: MiDaS 深度图
  40. # ============================================================
  41. def extract_depth(img_path, output_path, npy_path=None):
  42. """提取深度图"""
  43. from controlnet_aux import MidasDetector
  44. detector = MidasDetector.from_pretrained('lllyasviel/Annotators')
  45. img = Image.open(img_path)
  46. result = detector(img)
  47. result.save(output_path)
  48. print(f" Depth map saved: {output_path}")
  49. return True
  50. # ============================================================
  51. # 维度3: Lineart 线稿
  52. # ============================================================
  53. def extract_lineart(img_path, output_path):
  54. """提取线稿"""
  55. from controlnet_aux import LineartDetector
  56. detector = LineartDetector.from_pretrained('lllyasviel/Annotators')
  57. img = Image.open(img_path)
  58. result = detector(img, coarse=False)
  59. result.save(output_path)
  60. print(f" Lineart saved: {output_path}")
  61. return True
  62. # ============================================================
  63. # 维度4: 色彩调色板
  64. # ============================================================
  65. def extract_color_palette(img_path, output_path_json, output_path_png, n_colors=8):
  66. """提取主色调调色板"""
  67. # 使用ColorThief提取主色
  68. ct = ColorThief(img_path)
  69. palette = ct.get_palette(color_count=n_colors, quality=1)
  70. # 生成调色板可视化图
  71. palette_img = np.zeros((100, n_colors * 100, 3), dtype=np.uint8)
  72. for i, color in enumerate(palette):
  73. palette_img[:, i*100:(i+1)*100] = color
  74. cv2.imwrite(output_path_png, cv2.cvtColor(palette_img, cv2.COLOR_RGB2BGR))
  75. # 保存JSON
  76. palette_data = {
  77. "colors": [{"r": c[0], "g": c[1], "b": c[2],
  78. "hex": "#{:02x}{:02x}{:02x}".format(c[0], c[1], c[2])}
  79. for c in palette],
  80. "dominant_color": {"r": palette[0][0], "g": palette[0][1], "b": palette[0][2],
  81. "hex": "#{:02x}{:02x}{:02x}".format(palette[0][0], palette[0][1], palette[0][2])}
  82. }
  83. with open(output_path_json, 'w') as f:
  84. json.dump(palette_data, f, indent=2, ensure_ascii=False)
  85. print(f" Color palette saved: {output_path_png}")
  86. return palette_data
  87. # ============================================================
  88. # 维度5: 景深虚化遮罩(Bokeh Mask)
  89. # ============================================================
  90. def extract_bokeh_mask(img_path, depth_path, output_path):
  91. """
  92. 基于深度图推导景深虚化遮罩
  93. 亮区=近景清晰区域,暗区=远景虚化区域
  94. """
  95. img = cv2.imread(img_path)
  96. depth = cv2.imread(depth_path, cv2.IMREAD_GRAYSCALE)
  97. if depth is None:
  98. print(f" Warning: depth map not found for {img_path}")
  99. return False
  100. # 深度图归一化 - 调整到原图尺寸
  101. h_orig, w_orig = img.shape[:2]
  102. depth = cv2.resize(depth, (w_orig, h_orig), interpolation=cv2.INTER_LINEAR)
  103. depth_norm = depth.astype(np.float32) / 255.0
  104. # 计算局部清晰度(拉普拉斯方差)
  105. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  106. laplacian = cv2.Laplacian(gray, cv2.CV_64F)
  107. # 用高斯模糊计算局部清晰度图
  108. blur_map = np.abs(laplacian)
  109. blur_map = cv2.GaussianBlur(blur_map.astype(np.float32), (51, 51), 0)
  110. blur_map = (blur_map - blur_map.min()) / (blur_map.max() - blur_map.min() + 1e-8)
  111. # 结合深度图和清晰度图生成bokeh mask
  112. # 近景(深度值高)且清晰 = 主体区域
  113. bokeh_mask = (blur_map * 0.6 + depth_norm * 0.4) * 255
  114. bokeh_mask = bokeh_mask.astype(np.uint8)
  115. cv2.imwrite(output_path, bokeh_mask)
  116. print(f" Bokeh mask saved: {output_path}")
  117. return True
  118. # ============================================================
  119. # 维度6: 语义分割(基于颜色聚类)
  120. # ============================================================
  121. def extract_semantic_segmentation(img_path, output_path, n_segments=6):
  122. """
  123. 基于颜色聚类的语义分割
  124. 针对本图片组的特点:白裙/绿背景/调色板/画布
  125. """
  126. from sklearn.cluster import KMeans
  127. img = cv2.imread(img_path)
  128. img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  129. # 转换到LAB颜色空间(更符合人眼感知)
  130. img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
  131. h, w = img.shape[:2]
  132. # 加入位置信息(x, y坐标)以增强空间连续性
  133. y_coords, x_coords = np.mgrid[0:h, 0:w]
  134. y_norm = y_coords.astype(np.float32) / h * 30 # 位置权重
  135. x_norm = x_coords.astype(np.float32) / w * 30
  136. # 特征向量:LAB颜色 + 位置
  137. features = np.column_stack([
  138. img_lab.reshape(-1, 3).astype(np.float32),
  139. y_norm.reshape(-1, 1),
  140. x_norm.reshape(-1, 1)
  141. ])
  142. # K-means聚类
  143. kmeans = KMeans(n_clusters=n_segments, random_state=42, n_init=3)
  144. labels = kmeans.fit_predict(features)
  145. labels = labels.reshape(h, w)
  146. # 生成彩色分割图
  147. colors = [
  148. [255, 255, 255], # 白色 - 白裙
  149. [34, 139, 34], # 绿色 - 背景草地
  150. [101, 67, 33], # 棕色 - 调色板/画架
  151. [135, 206, 235], # 天蓝 - 天空/远景
  152. [255, 218, 185], # 肤色 - 人物皮肤
  153. [200, 200, 200], # 灰色 - 画布
  154. ]
  155. seg_img = np.zeros((h, w, 3), dtype=np.uint8)
  156. for i in range(n_segments):
  157. mask = labels == i
  158. # 找到该聚类的平均颜色
  159. cluster_color = kmeans.cluster_centers_[i][:3]
  160. # 转回RGB
  161. cluster_lab = np.uint8([[cluster_color]])
  162. cluster_rgb = cv2.cvtColor(cluster_lab, cv2.COLOR_LAB2RGB)[0][0]
  163. seg_img[mask] = cluster_rgb
  164. cv2.imwrite(output_path, cv2.cvtColor(seg_img, cv2.COLOR_RGB2BGR))
  165. print(f" Segmentation saved: {output_path}")
  166. return True
  167. # ============================================================
  168. # 维度7: 色彩分布向量(HSV直方图)
  169. # ============================================================
  170. def extract_color_distribution(img_path, output_path_json, output_path_png):
  171. """
  172. 提取HSV色彩分布向量
  173. 捕捉图片的整体色调特征
  174. """
  175. img = cv2.imread(img_path)
  176. img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
  177. # 计算HSV直方图
  178. h_hist = cv2.calcHist([img_hsv], [0], None, [36], [0, 180]) # 色相
  179. s_hist = cv2.calcHist([img_hsv], [1], None, [32], [0, 256]) # 饱和度
  180. v_hist = cv2.calcHist([img_hsv], [2], None, [32], [0, 256]) # 明度
  181. # 归一化
  182. h_hist = h_hist.flatten() / h_hist.sum()
  183. s_hist = s_hist.flatten() / s_hist.sum()
  184. v_hist = v_hist.flatten() / v_hist.sum()
  185. # 计算统计特征
  186. h, w = img.shape[:2]
  187. total_pixels = h * w
  188. # 白色像素比例(白裙特征)
  189. white_mask = (img_hsv[:,:,1] < 30) & (img_hsv[:,:,2] > 200)
  190. white_ratio = white_mask.sum() / total_pixels
  191. # 绿色像素比例(背景特征)
  192. green_mask = (img_hsv[:,:,0] >= 35) & (img_hsv[:,:,0] <= 85) & (img_hsv[:,:,1] > 50)
  193. green_ratio = green_mask.sum() / total_pixels
  194. # 平均亮度
  195. mean_brightness = img_hsv[:,:,2].mean() / 255.0
  196. # 平均饱和度
  197. mean_saturation = img_hsv[:,:,1].mean() / 255.0
  198. data = {
  199. "h_histogram": h_hist.tolist(),
  200. "s_histogram": s_hist.tolist(),
  201. "v_histogram": v_hist.tolist(),
  202. "statistics": {
  203. "white_ratio": float(white_ratio),
  204. "green_ratio": float(green_ratio),
  205. "mean_brightness": float(mean_brightness),
  206. "mean_saturation": float(mean_saturation)
  207. }
  208. }
  209. with open(output_path_json, 'w') as f:
  210. json.dump(data, f, indent=2)
  211. # 生成可视化图
  212. fig_h = 200
  213. fig_w = 600
  214. vis = np.ones((fig_h, fig_w, 3), dtype=np.uint8) * 240
  215. # 绘制色相直方图(彩色)
  216. bar_w = fig_w // 36
  217. for i, val in enumerate(h_hist):
  218. bar_h = int(val * (fig_h - 20))
  219. hue = int(i * 5) # 0-180
  220. color_hsv = np.uint8([[[hue, 200, 200]]])
  221. color_rgb = cv2.cvtColor(color_hsv, cv2.COLOR_HSV2BGR)[0][0]
  222. x1 = i * bar_w
  223. x2 = x1 + bar_w - 1
  224. y1 = fig_h - bar_h - 10
  225. y2 = fig_h - 10
  226. cv2.rectangle(vis, (x1, y1), (x2, y2), color_rgb.tolist(), -1)
  227. cv2.imwrite(output_path_png, vis)
  228. print(f" Color distribution saved: {output_path_png}")
  229. return data
  230. # ============================================================
  231. # 主执行流程
  232. # ============================================================
  233. print("\n--- 加载检测器 ---")
  234. from controlnet_aux import OpenposeDetector, MidasDetector, LineartDetector
  235. print("Loading OpenPose...")
  236. openpose_detector = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
  237. print("Loading MiDaS...")
  238. midas_detector = MidasDetector.from_pretrained('lllyasviel/Annotators')
  239. print("Loading Lineart...")
  240. lineart_detector = LineartDetector.from_pretrained('lllyasviel/Annotators')
  241. print("All detectors loaded!")
  242. # 处理每张图片
  243. for i in range(1, 10):
  244. img_name = f"img_{i}"
  245. img_path = os.path.join(INPUT_DIR, f"{img_name}.jpg")
  246. print(f"\n=== 处理 {img_name} ===")
  247. # 1. OpenPose
  248. openpose_path = os.path.join(OUTPUT_DIR, 'openpose_skeleton', f"{img_name}.png")
  249. img = Image.open(img_path)
  250. result = openpose_detector(img, hand_and_face=True)
  251. result.save(openpose_path)
  252. print(f" [1/7] OpenPose: {openpose_path}")
  253. # 2. Depth Map
  254. depth_path = os.path.join(OUTPUT_DIR, 'depth_map', f"{img_name}.png")
  255. result = midas_detector(img)
  256. result.save(depth_path)
  257. print(f" [2/7] Depth: {depth_path}")
  258. # 3. Lineart
  259. lineart_path = os.path.join(OUTPUT_DIR, 'lineart_edge', f"{img_name}.png")
  260. result = lineart_detector(img, coarse=False)
  261. result.save(lineart_path)
  262. print(f" [3/7] Lineart: {lineart_path}")
  263. # 4. Color Palette
  264. palette_json = os.path.join(OUTPUT_DIR, 'color_palette', f"{img_name}.json")
  265. palette_png = os.path.join(OUTPUT_DIR, 'color_palette', f"{img_name}.png")
  266. extract_color_palette(img_path, palette_json, palette_png, n_colors=8)
  267. print(f" [4/7] Color Palette: {palette_png}")
  268. # 5. Bokeh Mask
  269. bokeh_path = os.path.join(OUTPUT_DIR, 'bokeh_mask', f"{img_name}.png")
  270. extract_bokeh_mask(img_path, depth_path, bokeh_path)
  271. print(f" [5/7] Bokeh Mask: {bokeh_path}")
  272. # 6. Semantic Segmentation
  273. seg_path = os.path.join(OUTPUT_DIR, 'semantic_segmentation', f"{img_name}.png")
  274. extract_semantic_segmentation(img_path, seg_path, n_segments=6)
  275. print(f" [6/7] Segmentation: {seg_path}")
  276. # 7. Color Distribution
  277. dist_json = os.path.join(OUTPUT_DIR, 'color_distribution', f"{img_name}.json")
  278. dist_png = os.path.join(OUTPUT_DIR, 'color_distribution', f"{img_name}.png")
  279. extract_color_distribution(img_path, dist_json, dist_png)
  280. print(f" [7/7] Color Distribution: {dist_png}")
  281. print("\n=== 所有特征提取完成 ===")