| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152 |
- import base64
- import io
- import math
- from typing import List, Literal
- from PIL import Image
def _decode_image(img_b64: str, background_color: str) -> "Image.Image":
    """Decode one base64-encoded image file to RGB, flattening transparency onto *background_color*."""
    with Image.open(io.BytesIO(base64.b64decode(img_b64))) as src:
        # A bare convert("RGB") just discards the alpha band, so fully transparent pixels
        # would show whatever colour they happen to store (often black) instead of the background.
        if "A" not in src.getbands() and "transparency" not in src.info:
            return src.convert("RGB")
        rgba = src.convert("RGBA")
    flattened = Image.new("RGB", rgba.size, background_color)
    flattened.paste(rgba, (0, 0), rgba.getchannel("A"))
    return flattened


def _fit_images(pil_images: List["Image.Image"], resize_mode: str) -> List["Image.Image"]:
    """Scale images (keeping aspect ratio) to the largest width or height, as selected by *resize_mode*."""
    if resize_mode == "fit_width":
        target = max(img.width for img in pil_images)
        return [
            img if img.width == target
            # round() instead of int() truncation; max(1, ...) keeps very wide/tall images from hitting 0 px.
            else img.resize((target, max(1, round(img.height * target / img.width))), Image.LANCZOS)
            for img in pil_images
        ]
    if resize_mode == "fit_height":
        target = max(img.height for img in pil_images)
        return [
            img if img.height == target
            else img.resize((max(1, round(img.width * target / img.height)), target), Image.LANCZOS)
            for img in pil_images
        ]
    return pil_images


def _layout(sizes: List[tuple], direction: str, columns: int, spacing: int) -> tuple:
    """Return ``(canvas_width, canvas_height, offsets)``, where offsets[i] is the top-left corner of image i."""
    count = len(sizes)
    widths = [w for w, _ in sizes]
    heights = [h for _, h in sizes]
    if direction == "horizontal":
        offsets, x = [], 0
        for w in widths:
            offsets.append((x, 0))
            x += w + spacing
        return sum(widths) + spacing * (count - 1), max(heights), offsets
    if direction == "vertical":
        offsets, y = [], 0
        for h in heights:
            offsets.append((0, y))
            y += h + spacing
        return max(widths), sum(heights) + spacing * (count - 1), offsets
    # Grid: every cell is as large as the largest image; each image sits in its cell's top-left corner.
    rows = math.ceil(count / columns)
    cell_w, cell_h = max(widths), max(heights)
    offsets = [
        ((i % columns) * (cell_w + spacing), (i // columns) * (cell_h + spacing))
        for i in range(count)
    ]
    return cell_w * columns + spacing * (columns - 1), cell_h * rows + spacing * (rows - 1), offsets


def stitch_images(images: List[str], direction: Literal["horizontal", "vertical", "grid"] = "horizontal",
                  columns: int = 2, spacing: int = 0, background_color: str = "#FFFFFF",
                  resize_mode: Literal["none", "fit_width", "fit_height"] = "none") -> dict:
    """Stitch several base64-encoded images into a single PNG.

    Args:
        images: Base64-encoded image files (any format Pillow can open). At least two are required.
        direction: ``"horizontal"`` places images left to right, ``"vertical"`` top to bottom,
            ``"grid"`` fills rows of ``columns`` equally sized cells, left to right then top to bottom.
        columns: Cells per row; only used when ``direction == "grid"``, where it must be >= 1.
        spacing: Gap in pixels between neighbouring images (a negative value makes them overlap).
        background_color: Colour string accepted by Pillow; fills gaps, padding and transparent pixels.
        resize_mode: ``"fit_width"`` scales every image (aspect ratio kept) to the widest width,
            ``"fit_height"`` to the tallest height, ``"none"`` keeps original sizes.

    Returns:
        ``{"image": <base64 PNG>, "width": int, "height": int}`` describing the stitched result.

    Raises:
        ValueError: fewer than two images, unknown ``direction``/``resize_mode``, ``columns < 1`` in
            grid mode, or malformed base64 (``binascii.Error`` subclasses ``ValueError``).
    """
    if len(images) < 2:
        raise ValueError("至少需要 2 张图片")
    if direction not in ("horizontal", "vertical", "grid"):
        raise ValueError(f"不支持的拼接方向: {direction!r}")
    if resize_mode not in ("none", "fit_width", "fit_height"):
        raise ValueError(f"不支持的缩放模式: {resize_mode!r}")
    if direction == "grid" and columns < 1:
        # Validated up front: columns == 0 previously surfaced as a ZeroDivisionError in the row count.
        raise ValueError(f"columns 必须大于等于 1，当前为 {columns}")

    pil_images = _fit_images([_decode_image(b64, background_color) for b64 in images], resize_mode)
    width, height, offsets = _layout([img.size for img in pil_images], direction, columns, spacing)

    result = Image.new("RGB", (width, height), background_color)
    for img, offset in zip(pil_images, offsets):
        result.paste(img, offset)

    buf = io.BytesIO()
    result.save(buf, format="PNG")
    return {"image": base64.b64encode(buf.getvalue()).decode(), "width": result.width, "height": result.height}
|