| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335 |
- """
- 多模态特征提取脚本
- 针对"户外白裙写生少女"图片组,提取以下维度:
- 1. openpose_skeleton - 人体姿态骨架(OpenPose)
- 2. depth_map - 深度图(MiDaS)
- 3. lineart_edge - 线稿/边缘图(Lineart)
- 4. color_palette - 色彩调色板(ColorThief + 自定义)
- 5. bokeh_mask - 景深虚化遮罩(基于深度图推导)
- 6. semantic_segmentation - 语义分割(基于颜色聚类)
- 7. color_distribution - 色彩分布向量(HSV直方图)
- """
- import os
- import json
- import warnings
- warnings.filterwarnings('ignore')
- import numpy as np
- from PIL import Image
- import cv2
- from colorthief import ColorThief
- # 设置工作目录
- WORKDIR = os.path.dirname(os.path.abspath(__file__))
- INPUT_DIR = os.path.join(WORKDIR, 'input')
- OUTPUT_DIR = os.path.join(WORKDIR, 'output', 'features')
- print("=== 开始多模态特征提取 ===")
- print(f"工作目录: {WORKDIR}")
# ============================================================
# Dimension 1: OpenPose human pose skeleton
# ============================================================
def extract_openpose(img_path, output_path, detector=None):
    """Extract a human pose skeleton image with OpenPose.

    Args:
        img_path: path of the source image.
        output_path: where the rendered skeleton image is saved.
        detector: optional pre-loaded ``OpenposeDetector``. When ``None``,
            the pretrained model is downloaded/loaded on every call, which
            is expensive — pass a shared instance when processing batches.

    Returns:
        True on success.
    """
    from controlnet_aux import OpenposeDetector

    if detector is None:
        # Hub load is slow; only done when no shared detector is supplied.
        detector = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
    img = Image.open(img_path)
    result = detector(img, hand_and_face=True)
    result.save(output_path)
    print(f" OpenPose saved: {output_path}")
    return True
# ============================================================
# Dimension 2: MiDaS depth map
# ============================================================
def extract_depth(img_path, output_path, npy_path=None, detector=None):
    """Extract a depth map with MiDaS.

    Args:
        img_path: path of the source image.
        output_path: where the depth map image is saved.
        npy_path: optional path; when given, the raw depth values are also
            saved as a ``.npy`` array. (Previously this parameter was
            accepted but silently ignored.)
        detector: optional pre-loaded ``MidasDetector``; loaded from the
            pretrained weights on each call when ``None``.

    Returns:
        True on success.
    """
    from controlnet_aux import MidasDetector

    if detector is None:
        detector = MidasDetector.from_pretrained('lllyasviel/Annotators')
    img = Image.open(img_path)
    result = detector(img)
    result.save(output_path)
    if npy_path is not None:
        # Persist raw values for downstream numeric use (e.g. bokeh mask).
        np.save(npy_path, np.asarray(result))
    print(f" Depth map saved: {output_path}")
    return True
# ============================================================
# Dimension 3: Lineart line art
# ============================================================
def extract_lineart(img_path, output_path, detector=None):
    """Extract a line-art (edge) image.

    Args:
        img_path: path of the source image.
        output_path: where the line-art image is saved.
        detector: optional pre-loaded ``LineartDetector``; loaded from the
            pretrained weights on each call when ``None`` (slow — pass a
            shared instance when processing batches).

    Returns:
        True on success.
    """
    from controlnet_aux import LineartDetector

    if detector is None:
        detector = LineartDetector.from_pretrained('lllyasviel/Annotators')
    img = Image.open(img_path)
    # coarse=False keeps the fine line-art variant.
    result = detector(img, coarse=False)
    result.save(output_path)
    print(f" Lineart saved: {output_path}")
    return True
# ============================================================
# Dimension 4: color palette
# ============================================================
def extract_color_palette(img_path, output_path_json, output_path_png, n_colors=8):
    """Extract the dominant-color palette of an image.

    Uses ColorThief to quantize the image, writes a horizontal swatch-strip
    PNG and a JSON file with RGB/hex values for each color plus the single
    dominant color.

    Args:
        img_path: path of the source image.
        output_path_json: destination for the palette JSON.
        output_path_png: destination for the swatch-strip visualization.
        n_colors: number of palette colors requested from ColorThief.

    Returns:
        The palette dict that was written to ``output_path_json``.
    """
    ct = ColorThief(img_path)
    palette = ct.get_palette(color_count=n_colors, quality=1)

    # ColorThief may return fewer (or occasionally more) colors than
    # requested, so size the swatch strip by the actual palette length —
    # sizing by n_colors left black filler columns or dropped swatches.
    palette_img = np.zeros((100, len(palette) * 100, 3), dtype=np.uint8)
    for i, color in enumerate(palette):
        palette_img[:, i * 100:(i + 1) * 100] = color

    # The strip is built in RGB; OpenCV writes BGR, hence the conversion.
    cv2.imwrite(output_path_png, cv2.cvtColor(palette_img, cv2.COLOR_RGB2BGR))

    palette_data = {
        "colors": [
            {"r": c[0], "g": c[1], "b": c[2],
             "hex": "#{:02x}{:02x}{:02x}".format(c[0], c[1], c[2])}
            for c in palette
        ],
        "dominant_color": {
            "r": palette[0][0], "g": palette[0][1], "b": palette[0][2],
            "hex": "#{:02x}{:02x}{:02x}".format(
                palette[0][0], palette[0][1], palette[0][2]),
        },
    }
    with open(output_path_json, 'w') as f:
        json.dump(palette_data, f, indent=2, ensure_ascii=False)

    print(f" Color palette saved: {output_path_png}")
    return palette_data
# ============================================================
# Dimension 5: depth-of-field blur mask (bokeh mask)
# ============================================================
def extract_bokeh_mask(img_path, depth_path, output_path):
    """Derive a depth-of-field (bokeh) mask from an image and its depth map.

    Bright regions = near, in-focus areas; dark regions = far, blurred
    areas. The mask blends a local-sharpness map (Laplacian magnitude)
    with the normalized depth map.

    Args:
        img_path: path of the source image.
        depth_path: path of the previously extracted depth map.
        output_path: where the 8-bit grayscale mask is saved.

    Returns:
        True on success, False when either input image cannot be read.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None instead of raising; without this guard
        # the img.shape access below raised AttributeError for missing files.
        print(f" Warning: image not found: {img_path}")
        return False

    depth = cv2.imread(depth_path, cv2.IMREAD_GRAYSCALE)
    if depth is None:
        print(f" Warning: depth map not found for {img_path}")
        return False

    # Resize the depth map to the original image size and normalize to [0, 1].
    h_orig, w_orig = img.shape[:2]
    depth = cv2.resize(depth, (w_orig, h_orig), interpolation=cv2.INTER_LINEAR)
    depth_norm = depth.astype(np.float32) / 255.0

    # Local sharpness: Laplacian magnitude, smoothed into a regional map.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    laplacian = cv2.Laplacian(gray, cv2.CV_64F)

    blur_map = np.abs(laplacian)
    blur_map = cv2.GaussianBlur(blur_map.astype(np.float32), (51, 51), 0)
    # Min-max normalize; epsilon avoids division by zero on flat images.
    blur_map = (blur_map - blur_map.min()) / (blur_map.max() - blur_map.min() + 1e-8)

    # Blend sharpness and depth: near (high depth value) and sharp = subject.
    bokeh_mask = (blur_map * 0.6 + depth_norm * 0.4) * 255
    bokeh_mask = bokeh_mask.astype(np.uint8)

    cv2.imwrite(output_path, bokeh_mask)
    print(f" Bokeh mask saved: {output_path}")
    return True
# ============================================================
# Dimension 6: semantic segmentation (color clustering)
# ============================================================
def extract_semantic_segmentation(img_path, output_path, n_segments=6):
    """Segment an image by K-means clustering in LAB color + position space.

    Tailored to this image set (white dress / green background / palette /
    canvas): pixels are clustered on LAB color (perceptually more uniform
    than RGB) plus weighted (x, y) coordinates for spatial coherence, and
    each segment is painted with its cluster's mean color.

    Args:
        img_path: path of the source image.
        output_path: where the colorized segmentation map is saved.
        n_segments: number of K-means clusters.

    Returns:
        True on success.
    """
    from sklearn.cluster import KMeans

    img = cv2.imread(img_path)
    # LAB is closer to human color perception than BGR/RGB.
    img_lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)

    h, w = img.shape[:2]
    # Normalized pixel coordinates, scaled by 30 so position influences the
    # clustering without dominating the (0-255 range) color channels.
    y_coords, x_coords = np.mgrid[0:h, 0:w]
    y_norm = y_coords.astype(np.float32) / h * 30
    x_norm = x_coords.astype(np.float32) / w * 30

    # Feature vector per pixel: LAB color + weighted position.
    features = np.column_stack([
        img_lab.reshape(-1, 3).astype(np.float32),
        y_norm.reshape(-1, 1),
        x_norm.reshape(-1, 1),
    ])

    kmeans = KMeans(n_clusters=n_segments, random_state=42, n_init=3)
    labels = kmeans.fit_predict(features).reshape(h, w)

    # Paint each segment with its cluster's mean LAB color converted to RGB.
    # (A previous hard-coded per-class color table was dead code and has
    # been removed — the mean colors below were always what got drawn.)
    seg_img = np.zeros((h, w, 3), dtype=np.uint8)
    for i in range(n_segments):
        mask = labels == i
        cluster_color = kmeans.cluster_centers_[i][:3]  # drop position dims
        cluster_lab = np.uint8([[cluster_color]])
        cluster_rgb = cv2.cvtColor(cluster_lab, cv2.COLOR_LAB2RGB)[0][0]
        seg_img[mask] = cluster_rgb

    cv2.imwrite(output_path, cv2.cvtColor(seg_img, cv2.COLOR_RGB2BGR))
    print(f" Segmentation saved: {output_path}")
    return True
# ============================================================
# Dimension 7: color distribution vector (HSV histograms)
# ============================================================
def extract_color_distribution(img_path, output_path_json, output_path_png):
    """Compute HSV color-distribution features for an image.

    Writes a JSON file containing normalized hue / saturation / value
    histograms plus summary statistics (white-pixel ratio for the dress,
    green-pixel ratio for the background, mean brightness and saturation),
    and a PNG bar chart of the hue histogram.

    Args:
        img_path: path of the source image.
        output_path_json: destination for the feature JSON.
        output_path_png: destination for the hue-histogram visualization.

    Returns:
        The dict that was written to ``output_path_json``.
    """
    bgr = cv2.imread(img_path)
    hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
    hue, sat, val = hsv[:, :, 0], hsv[:, :, 1], hsv[:, :, 2]

    # Per-channel histograms, normalized to probability distributions.
    hue_hist = cv2.calcHist([hsv], [0], None, [36], [0, 180])
    sat_hist = cv2.calcHist([hsv], [1], None, [32], [0, 256])
    val_hist = cv2.calcHist([hsv], [2], None, [32], [0, 256])
    hue_hist = hue_hist.flatten() / hue_hist.sum()
    sat_hist = sat_hist.flatten() / sat_hist.sum()
    val_hist = val_hist.flatten() / val_hist.sum()

    height, width = bgr.shape[:2]
    n_pixels = height * width

    # Low saturation + high value ~ white pixels (white-dress feature).
    white_ratio = ((sat < 30) & (val > 200)).sum() / n_pixels
    # Hue in the green band with real saturation (background feature).
    green_ratio = ((hue >= 35) & (hue <= 85) & (sat > 50)).sum() / n_pixels

    data = {
        "h_histogram": hue_hist.tolist(),
        "s_histogram": sat_hist.tolist(),
        "v_histogram": val_hist.tolist(),
        "statistics": {
            "white_ratio": float(white_ratio),
            "green_ratio": float(green_ratio),
            "mean_brightness": float(val.mean() / 255.0),
            "mean_saturation": float(sat.mean() / 255.0),
        },
    }
    with open(output_path_json, 'w') as f:
        json.dump(data, f, indent=2)

    # Hue-histogram bar chart on a light-gray canvas; each bar is drawn in
    # its own hue, converted to BGR for OpenCV drawing/writing.
    fig_h, fig_w = 200, 600
    canvas = np.ones((fig_h, fig_w, 3), dtype=np.uint8) * 240
    bar_w = fig_w // 36
    for idx, frac in enumerate(hue_hist):
        bar_h = int(frac * (fig_h - 20))
        swatch = np.uint8([[[int(idx * 5), 200, 200]]])  # hue bin center, 0-180
        bar_bgr = cv2.cvtColor(swatch, cv2.COLOR_HSV2BGR)[0][0]
        left = idx * bar_w
        cv2.rectangle(
            canvas,
            (left, fig_h - bar_h - 10),
            (left + bar_w - 1, fig_h - 10),
            bar_bgr.tolist(),
            -1,
        )
    cv2.imwrite(output_path_png, canvas)

    print(f" Color distribution saved: {output_path_png}")
    return data
# ============================================================
# Main execution flow
# ============================================================
print("\n--- 加载检测器 ---")
from controlnet_aux import OpenposeDetector, MidasDetector, LineartDetector

# Detectors are loaded once up front and reused for every image.
print("Loading OpenPose...")
openpose_detector = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
print("Loading MiDaS...")
midas_detector = MidasDetector.from_pretrained('lllyasviel/Annotators')
print("Loading Lineart...")
lineart_detector = LineartDetector.from_pretrained('lllyasviel/Annotators')
print("All detectors loaded!")

# Create every feature output directory up front; previously none were
# created, so the first result.save() failed on a fresh checkout.
for _subdir in ('openpose_skeleton', 'depth_map', 'lineart_edge',
                'color_palette', 'bokeh_mask', 'semantic_segmentation',
                'color_distribution'):
    os.makedirs(os.path.join(OUTPUT_DIR, _subdir), exist_ok=True)

# Process each image (img_1.jpg ... img_9.jpg).
for i in range(1, 10):
    img_name = f"img_{i}"
    img_path = os.path.join(INPUT_DIR, f"{img_name}.jpg")

    print(f"\n=== 处理 {img_name} ===")

    # 1. OpenPose skeleton
    openpose_path = os.path.join(OUTPUT_DIR, 'openpose_skeleton', f"{img_name}.png")
    img = Image.open(img_path)
    result = openpose_detector(img, hand_and_face=True)
    result.save(openpose_path)
    print(f" [1/7] OpenPose: {openpose_path}")

    # 2. Depth map (also consumed below by the bokeh mask)
    depth_path = os.path.join(OUTPUT_DIR, 'depth_map', f"{img_name}.png")
    result = midas_detector(img)
    result.save(depth_path)
    print(f" [2/7] Depth: {depth_path}")

    # 3. Line art
    lineart_path = os.path.join(OUTPUT_DIR, 'lineart_edge', f"{img_name}.png")
    result = lineart_detector(img, coarse=False)
    result.save(lineart_path)
    print(f" [3/7] Lineart: {lineart_path}")

    # 4. Color palette (JSON + swatch PNG)
    palette_json = os.path.join(OUTPUT_DIR, 'color_palette', f"{img_name}.json")
    palette_png = os.path.join(OUTPUT_DIR, 'color_palette', f"{img_name}.png")
    extract_color_palette(img_path, palette_json, palette_png, n_colors=8)
    print(f" [4/7] Color Palette: {palette_png}")

    # 5. Bokeh mask, derived from the depth map saved in step 2
    bokeh_path = os.path.join(OUTPUT_DIR, 'bokeh_mask', f"{img_name}.png")
    extract_bokeh_mask(img_path, depth_path, bokeh_path)
    print(f" [5/7] Bokeh Mask: {bokeh_path}")

    # 6. Semantic segmentation
    seg_path = os.path.join(OUTPUT_DIR, 'semantic_segmentation', f"{img_name}.png")
    extract_semantic_segmentation(img_path, seg_path, n_segments=6)
    print(f" [6/7] Segmentation: {seg_path}")

    # 7. Color distribution (JSON + histogram PNG)
    dist_json = os.path.join(OUTPUT_DIR, 'color_distribution', f"{img_name}.json")
    dist_png = os.path.join(OUTPUT_DIR, 'color_distribution', f"{img_name}.png")
    extract_color_distribution(img_path, dist_json, dist_png)
    print(f" [7/7] Color Distribution: {dist_png}")

print("\n=== 所有特征提取完成 ===")
|