extract_features.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. """
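Multi-modal feature extraction script - plein-air oil painting image set.

Extracted dimensions:
1. character_reference - character reference images (content; for nanobanana)
2. pose_skeleton       - human pose skeleton (DWPose, one per image)
3. palette_texture     - paint palette / pigment texture (content; reference frames are saved as-is, not cropped)
4. painting_tools      - painting tools (content; reference frames are saved as-is, not cropped)
5. natural_background  - natural background (content; rembg-based subject separation)
6. depth_map           - depth map (form; Depth Anything V2)
7. color_palette_text  - textual colour palette description (form)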
  11. """
  12. import os
  13. import json
  14. import warnings
  15. warnings.filterwarnings('ignore')
  16. import numpy as np
  17. from PIL import Image, ImageDraw, ImageFont
  18. import cv2
  19. BASE_DIR = "/Users/liuxiaobai/Desktop/Agent/Agent/examples/find knowledge"
  20. INPUT_DIR = os.path.join(BASE_DIR, "input")
  21. OUTPUT_DIR = os.path.join(BASE_DIR, "output/features")
  22. # 确保输出目录存在
  23. for d in ['character_reference', 'pose_skeleton', 'palette_texture',
  24. 'painting_tools', 'natural_background', 'depth_map', 'color_palette_text']:
  25. os.makedirs(os.path.join(OUTPUT_DIR, d), exist_ok=True)
  26. print("=" * 60)
  27. print("步骤1: 加载所有图片")
  28. print("=" * 60)
  29. images = {}
  30. for i in range(1, 10):
  31. path = os.path.join(INPUT_DIR, f"img_{i}.jpg")
  32. img = Image.open(path).convert("RGB")
  33. images[f"img_{i}"] = img
  34. print(f" img_{i}: {img.size}")
  35. # ============================================================
  36. # 维度1: character_reference - 人物参考图
  37. # 策略:从img_7(侧脸特写)提取最清晰的人物面部+身体参考
  38. # 同时从img_6(背部特写)提取背影参考
  39. # ============================================================
  40. print("\n" + "=" * 60)
  41. print("步骤2: 提取人物参考图 (character_reference)")
  42. print("=" * 60)
  43. # img_7是侧脸特写,最能体现人物面部特征
  44. # img_6是背部+耳饰特写
  45. # img_1是全身最完整的侧后方视角
  46. # 保存关键参考图(不做任何修改,直接保存原图)
  47. ref_imgs = {
  48. "img_7_face_reference": images["img_7"], # 侧脸+玫瑰,最清晰面部
  49. "img_6_back_reference": images["img_6"], # 背部特写+耳饰
  50. "img_1_full_reference": images["img_1"], # 全身参考
  51. }
  52. for name, img in ref_imgs.items():
  53. out_path = os.path.join(OUTPUT_DIR, "character_reference", f"{name}.png")
  54. img.save(out_path)
  55. print(f" 保存: {name}.png ({img.size})")
  56. # ============================================================
  57. # 维度2: pose_skeleton - 人体姿态骨架 (DWPose)
  58. # ============================================================
  59. print("\n" + "=" * 60)
  60. print("步骤3: 提取人体姿态骨架 (DWPose)")
  61. print("=" * 60)
  62. try:
  63. from controlnet_aux import DWposeDetector
  64. dwpose = DWposeDetector()
  65. print(" DWPose加载成功")
  66. # 对每张图提取姿态
  67. pose_imgs = ["img_1", "img_2", "img_3", "img_4", "img_8", "img_9"] # 全身/半身图
  68. for img_id in pose_imgs:
  69. img = images[img_id]
  70. try:
  71. pose_result = dwpose(img, detect_resolution=512, image_resolution=img.size[0])
  72. out_path = os.path.join(OUTPUT_DIR, "pose_skeleton", f"{img_id}_dwpose.png")
  73. pose_result.save(out_path)
  74. print(f" ✓ {img_id}: 姿态提取成功")
  75. except Exception as e:
  76. print(f" ✗ {img_id}: {e}")
  77. # 降级:使用OpenPose
  78. try:
  79. from controlnet_aux import OpenposeDetector
  80. openpose = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
  81. pose_result = openpose(img, detect_resolution=512, image_resolution=img.size[0])
  82. out_path = os.path.join(OUTPUT_DIR, "pose_skeleton", f"{img_id}_openpose.png")
  83. pose_result.save(out_path)
  84. print(f" ✓ {img_id}: OpenPose降级成功")
  85. except Exception as e2:
  86. print(f" ✗ {img_id} OpenPose也失败: {e2}")
  87. except Exception as e:
  88. print(f" DWPose加载失败: {e}")
  89. print(" 尝试OpenPose...")
  90. try:
  91. from controlnet_aux import OpenposeDetector
  92. openpose = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
  93. print(" OpenPose加载成功")
  94. pose_imgs = ["img_1", "img_2", "img_3", "img_4", "img_8", "img_9"]
  95. for img_id in pose_imgs:
  96. img = images[img_id]
  97. try:
  98. pose_result = openpose(img, detect_resolution=512, image_resolution=img.size[0])
  99. out_path = os.path.join(OUTPUT_DIR, "pose_skeleton", f"{img_id}_openpose.png")
  100. pose_result.save(out_path)
  101. print(f" ✓ {img_id}: OpenPose成功")
  102. except Exception as e2:
  103. print(f" ✗ {img_id}: {e2}")
  104. except Exception as e3:
  105. print(f" OpenPose也失败: {e3}")
  106. # ============================================================
  107. # 维度3: palette_texture - 调色板颜料质感
  108. # 策略:从img_5(调色板特写)裁剪调色板区域
  109. # ============================================================
  110. print("\n" + "=" * 60)
  111. print("步骤4: 提取调色板颜料质感 (palette_texture)")
  112. print("=" * 60)
  113. # img_5是调色板最清晰的特写
  114. # img_6也有调色板
  115. palette_imgs = {
  116. "img_5_palette_closeup": images["img_5"], # 调色板特写
  117. "img_6_palette_detail": images["img_6"], # 作画特写含调色板
  118. }
  119. for name, img in palette_imgs.items():
  120. out_path = os.path.join(OUTPUT_DIR, "palette_texture", f"{name}.png")
  121. img.save(out_path)
  122. print(f" 保存: {name}.png")
  123. # ============================================================
  124. # 维度4: painting_tools - 绘画工具(画架+画布)
  125. # 策略:从img_4(画架+空白画布最清晰)提取
  126. # ============================================================
  127. print("\n" + "=" * 60)
  128. print("步骤5: 提取绘画工具参考 (painting_tools)")
  129. print("=" * 60)
  130. tool_imgs = {
  131. "img_4_easel_blank_canvas": images["img_4"], # 画架+空白画布
  132. "img_8_easel_with_rose": images["img_8"], # 画架+玫瑰花
  133. "img_3_easel_painting": images["img_3"], # 画架+油画作品
  134. }
  135. for name, img in tool_imgs.items():
  136. out_path = os.path.join(OUTPUT_DIR, "painting_tools", f"{name}.png")
  137. img.save(out_path)
  138. print(f" 保存: {name}.png")
  139. # ============================================================
  140. # 维度5: natural_background - 自然背景
  141. # 策略:使用rembg去除主体,保留背景
  142. # ============================================================
  143. print("\n" + "=" * 60)
  144. print("步骤6: 提取自然背景 (natural_background)")
  145. print("=" * 60)
  146. try:
  147. from rembg import remove
  148. print(" rembg加载成功")
  149. # 选择背景最清晰的图片
  150. bg_imgs = ["img_9", "img_3", "img_1"] # 背景占比大的图
  151. for img_id in bg_imgs:
  152. img = images[img_id]
  153. try:
  154. # 去除前景,保留背景
  155. result = remove(img)
  156. # 将透明区域填充为白色(前景位置),保留背景
  157. bg_array = np.array(result)
  158. # 创建背景蒙版:alpha=0的区域是前景(被去除的),alpha>0是背景
  159. # 实际上rembg去除背景,我们需要反向操作
  160. # 直接保存原图作为背景参考,并保存去背景版本
  161. # 保存原图(背景参考)
  162. out_path = os.path.join(OUTPUT_DIR, "natural_background", f"{img_id}_bg_reference.png")
  163. img.save(out_path)
  164. # 保存去主体版本(背景分离)
  165. out_path2 = os.path.join(OUTPUT_DIR, "natural_background", f"{img_id}_fg_removed.png")
  166. result.save(out_path2)
  167. print(f" ✓ {img_id}: 背景提取成功")
  168. except Exception as e:
  169. print(f" ✗ {img_id}: {e}")
  170. img.save(os.path.join(OUTPUT_DIR, "natural_background", f"{img_id}_bg_reference.png"))
  171. except Exception as e:
  172. print(f" rembg失败: {e}")
  173. # 降级:直接保存背景参考图
  174. for img_id in ["img_9", "img_3", "img_1"]:
  175. images[img_id].save(os.path.join(OUTPUT_DIR, "natural_background", f"{img_id}_bg_reference.png"))
  176. print(f" 降级保存: {img_id}")
  177. # ============================================================
  178. # 维度6: depth_map - 深度图 (Depth Anything V2)
  179. # ============================================================
  180. print("\n" + "=" * 60)
  181. print("步骤7: 提取深度图 (Depth Anything)")
  182. print("=" * 60)
  183. try:
  184. from transformers import pipeline
  185. print(" 加载Depth Anything V2...")
  186. # 使用Depth Anything V2 - 最新最强的单目深度估计模型
  187. depth_pipe = pipeline(
  188. task="depth-estimation",
  189. model="depth-anything/Depth-Anything-V2-Small-hf",
  190. device="cpu"
  191. )
  192. print(" Depth Anything V2加载成功")
  193. # 对所有图提取深度图
  194. for img_id, img in images.items():
  195. try:
  196. result = depth_pipe(img)
  197. depth_img = result["depth"]
  198. # 转换为可视化深度图
  199. depth_array = np.array(depth_img)
  200. # 归一化到0-255
  201. depth_norm = ((depth_array - depth_array.min()) /
  202. (depth_array.max() - depth_array.min()) * 255).astype(np.uint8)
  203. depth_visual = Image.fromarray(depth_norm)
  204. out_path = os.path.join(OUTPUT_DIR, "depth_map", f"{img_id}_depth.png")
  205. depth_visual.save(out_path)
  206. print(f" ✓ {img_id}: 深度图提取成功")
  207. except Exception as e:
  208. print(f" ✗ {img_id}: {e}")
  209. except Exception as e:
  210. print(f" Depth Anything失败: {e}")
  211. print(" 尝试controlnet_aux的MiDaS...")
  212. try:
  213. from controlnet_aux import MidasDetector
  214. midas = MidasDetector.from_pretrained("lllyasviel/Annotators")
  215. print(" MiDaS加载成功")
  216. for img_id, img in images.items():
  217. try:
  218. depth_result = midas(img, detect_resolution=512, image_resolution=img.size[0])
  219. out_path = os.path.join(OUTPUT_DIR, "depth_map", f"{img_id}_midas_depth.png")
  220. depth_result.save(out_path)
  221. print(f" ✓ {img_id}: MiDaS深度图成功")
  222. except Exception as e2:
  223. print(f" ✗ {img_id}: {e2}")
  224. except Exception as e3:
  225. print(f" MiDaS也失败: {e3}")
  226. # ============================================================
  227. # 维度7: color_palette_text - 色彩调色板(文字描述)
  228. # 使用Python提取主色调,生成专业色彩描述
  229. # ============================================================
  230. print("\n" + "=" * 60)
  231. print("步骤8: 提取色彩调色板 (color_palette_text)")
  232. print("=" * 60)
  233. def extract_color_palette(img, n_colors=8):
  234. """提取图片主色调"""
  235. img_small = img.resize((150, 150))
  236. img_array = np.array(img_small).reshape(-1, 3).astype(float)
  237. # K-means聚类
  238. from sklearn.cluster import KMeans
  239. kmeans = KMeans(n_clusters=n_colors, random_state=42, n_init=10)
  240. kmeans.fit(img_array)
  241. colors = kmeans.cluster_centers_.astype(int)
  242. labels = kmeans.labels_
  243. # 计算每个颜色的占比
  244. counts = np.bincount(labels)
  245. percentages = counts / len(labels) * 100
  246. # 按占比排序
  247. sorted_idx = np.argsort(percentages)[::-1]
  248. colors = colors[sorted_idx]
  249. percentages = percentages[sorted_idx]
  250. return colors, percentages
  251. def rgb_to_hex(rgb):
  252. return f"#{rgb[0]:02X}{rgb[1]:02X}{rgb[2]:02X}"
  253. def rgb_to_hsv_desc(rgb):
  254. """将RGB转为HSV并给出描述"""
  255. r, g, b = rgb[0]/255, rgb[1]/255, rgb[2]/255
  256. h, s, v = cv2.cvtColor(np.array([[[rgb[0], rgb[1], rgb[2]]]], dtype=np.uint8),
  257. cv2.COLOR_RGB2HSV)[0][0]
  258. # 色相描述
  259. if s < 30:
  260. if v < 50: hue_name = "black"
  261. elif v > 200: hue_name = "white"
  262. else: hue_name = "gray"
  263. elif h < 15 or h > 165: hue_name = "red"
  264. elif h < 30: hue_name = "orange"
  265. elif h < 45: hue_name = "yellow"
  266. elif h < 75: hue_name = "yellow-green"
  267. elif h < 105: hue_name = "green"
  268. elif h < 120: hue_name = "cyan-green"
  269. elif h < 135: hue_name = "cyan"
  270. elif h < 150: hue_name = "blue-cyan"
  271. elif h < 165: hue_name = "blue"
  272. else: hue_name = "purple"
  273. # 饱和度描述
  274. if s < 50: sat_name = "desaturated"
  275. elif s < 120: sat_name = "muted"
  276. elif s < 200: sat_name = "saturated"
  277. else: sat_name = "vivid"
  278. # 亮度描述
  279. if v < 80: val_name = "dark"
  280. elif v < 160: val_name = "mid-tone"
  281. else: val_name = "light"
  282. return f"{val_name} {sat_name} {hue_name}", int(h)*2, int(s/255*100), int(v/255*100)
  283. try:
  284. from sklearn.cluster import KMeans
  285. color_data = {}
  286. for img_id, img in images.items():
  287. colors, percentages = extract_color_palette(img, n_colors=8)
  288. palette_info = []
  289. for i, (color, pct) in enumerate(zip(colors, percentages)):
  290. desc, h, s, v = rgb_to_hsv_desc(color)
  291. palette_info.append({
  292. "rank": i + 1,
  293. "hex": rgb_to_hex(color),
  294. "rgb": [int(color[0]), int(color[1]), int(color[2])],
  295. "hsv": {"h": h, "s": s, "v": v},
  296. "description": desc,
  297. "percentage": round(float(pct), 1)
  298. })
  299. color_data[img_id] = palette_info
  300. print(f" ✓ {img_id}: 提取{len(palette_info)}个主色调")
  301. for p in palette_info[:3]:
  302. print(f" {p['hex']} ({p['percentage']}%) - {p['description']}")
  303. # 保存色彩数据
  304. out_path = os.path.join(OUTPUT_DIR, "color_palette_text", "all_images_color_palette.json")
  305. with open(out_path, 'w', encoding='utf-8') as f:
  306. json.dump(color_data, f, ensure_ascii=False, indent=2)
  307. print(f"\n 色彩数据已保存: {out_path}")
  308. # 生成色彩可视化图
  309. for img_id, palette in color_data.items():
  310. palette_img = Image.new('RGB', (800, 120), 'white')
  311. draw = ImageDraw.Draw(palette_img)
  312. x = 0
  313. for p in palette[:8]:
  314. w = int(800 * p['percentage'] / 100)
  315. if w < 5: w = 5
  316. color_tuple = tuple(p['rgb'])
  317. draw.rectangle([x, 0, x+w, 80], fill=color_tuple)
  318. x += w
  319. out_path = os.path.join(OUTPUT_DIR, "color_palette_text", f"{img_id}_palette.png")
  320. palette_img.save(out_path)
  321. print(" 色彩可视化图已保存")
  322. except Exception as e:
  323. print(f" 色彩提取失败: {e}")
  324. print("\n" + "=" * 60)
  325. print("特征提取完成!")
  326. print("=" * 60)
  327. # 列出所有输出文件
  328. for dim in os.listdir(OUTPUT_DIR):
  329. dim_path = os.path.join(OUTPUT_DIR, dim)
  330. if os.path.isdir(dim_path):
  331. files = os.listdir(dim_path)
  332. print(f"\n {dim}/")
  333. for f in sorted(files):
  334. fpath = os.path.join(dim_path, f)
  335. size = os.path.getsize(fpath)
  336. print(f" {f} ({size//1024}KB)")