import base64
import io
import math
from typing import List, Literal

from PIL import Image


def _decode_image(img_b64: str, background_color: str) -> Image.Image:
    """Decode one base64 string into an RGB image, flattening alpha onto *background_color*."""
    with Image.open(io.BytesIO(base64.b64decode(img_b64))) as src:
        # A plain convert("RGB") just drops the alpha band, so transparent pixels would
        # show whatever colour data sits underneath (often black) instead of the canvas
        # colour. Detect alpha (explicit band or palette/colour-key transparency) first.
        has_alpha = src.mode in ("RGBA", "LA", "PA") or "transparency" in src.info
        if not has_alpha:
            # convert() loads the pixels and returns an independent image, so it stays
            # valid after the source is closed when the with block exits.
            return src.convert("RGB")
        rgba = src.convert("RGBA")
    flattened = Image.new("RGB", rgba.size, background_color)
    flattened.paste(rgba, mask=rgba.getchannel("A"))
    return flattened


def _normalize_sizes(pil_images: List[Image.Image], resize_mode: str) -> List[Image.Image]:
    """Scale images to a common width or height (aspect ratio kept) per *resize_mode*."""
    if resize_mode == "fit_width":
        target = max(img.width for img in pil_images)
        return [
            img if img.width == target
            # max(1, ...) guards against very thin images truncating to a 0-px side,
            # which Image.resize rejects.
            else img.resize((target, max(1, int(img.height * target / img.width))), Image.LANCZOS)
            for img in pil_images
        ]
    if resize_mode == "fit_height":
        target = max(img.height for img in pil_images)
        return [
            img if img.height == target
            else img.resize((max(1, int(img.width * target / img.height)), target), Image.LANCZOS)
            for img in pil_images
        ]
    return pil_images


def _layout(pil_images: List[Image.Image], direction: str, columns: int, spacing: int):
    """Return ``((canvas_w, canvas_h), [(x, y), ...])`` paste offsets for each image."""
    count = len(pil_images)
    if direction == "horizontal":
        # Single row, images top-aligned; canvas height is the tallest image.
        offsets, x = [], 0
        for img in pil_images:
            offsets.append((x, 0))
            x += img.width + spacing
        width = sum(img.width for img in pil_images) + spacing * (count - 1)
        return (width, max(img.height for img in pil_images)), offsets
    if direction == "vertical":
        # Single column, images left-aligned; canvas width is the widest image.
        offsets, y = [], 0
        for img in pil_images:
            offsets.append((0, y))
            y += img.height + spacing
        height = sum(img.height for img in pil_images) + spacing * (count - 1)
        return (max(img.width for img in pil_images), height), offsets
    # Grid: row-major; every cell is sized to the largest image, each image at its
    # cell's top-left corner.
    rows = math.ceil(count / columns)
    cell_w = max(img.width for img in pil_images)
    cell_h = max(img.height for img in pil_images)
    size = (cell_w * columns + spacing * (columns - 1), cell_h * rows + spacing * (rows - 1))
    offsets = [
        ((i % columns) * (cell_w + spacing), (i // columns) * (cell_h + spacing))
        for i in range(count)
    ]
    return size, offsets


def stitch_images(
    images: List[str],
    direction: Literal["horizontal", "vertical", "grid"] = "horizontal",
    columns: int = 2,
    spacing: int = 0,
    background_color: str = "#FFFFFF",
    resize_mode: Literal["none", "fit_width", "fit_height"] = "none",
) -> dict:
    """Stitch base64-encoded images into a single PNG.

    Args:
        images: base64-encoded image files (any format Pillow can open); at least 2.
        direction: "horizontal" (one top-aligned row), "vertical" (one left-aligned
            column) or "grid" (row-major, ``columns`` cells per row, each cell the
            size of the largest image).
        columns: cells per row; only used when ``direction == "grid"``, must be >= 1.
        spacing: gap in pixels between adjacent images / cells.
        background_color: Pillow colour spec for the canvas; transparent pixels in
            the inputs are also flattened onto this colour.
        resize_mode: "none" keeps original sizes; "fit_width" scales every image to
            the widest width, "fit_height" to the tallest height (aspect ratio kept).

    Returns:
        ``{"image": <base64 PNG string>, "width": int, "height": int}``.

    Raises:
        ValueError: fewer than 2 images, unknown ``direction`` / ``resize_mode``,
            or ``columns < 1`` in grid mode.
    """
    if len(images) < 2:
        raise ValueError("至少需要 2 张图片")
    if direction not in ("horizontal", "vertical", "grid"):
        raise ValueError(f"不支持的拼接方向: {direction!r}")
    if resize_mode not in ("none", "fit_width", "fit_height"):
        raise ValueError(f"不支持的缩放模式: {resize_mode!r}")
    if direction == "grid" and columns < 1:
        # columns == 0 would otherwise raise ZeroDivisionError in math.ceil below.
        raise ValueError(f"columns 必须 >= 1, 当前为 {columns}")

    pil_images = [_decode_image(img_b64, background_color) for img_b64 in images]
    pil_images = _normalize_sizes(pil_images, resize_mode)

    size, offsets = _layout(pil_images, direction, columns, spacing)
    result = Image.new("RGB", size, background_color)
    for img, offset in zip(pil_images, offsets):
        result.paste(img, offset)

    buf = io.BytesIO()
    result.save(buf, format="PNG")
    return {
        "image": base64.b64encode(buf.getvalue()).decode(),
        "width": result.width,
        "height": result.height,
    }