| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- import base64
- import io
- import math
- import httpx
- import re
- from typing import List, Literal
- from PIL import Image
def _fetch_image(img_data: str) -> Image.Image:
    """Load an image from a URL, a base64 data URL, or a bare base64 string.

    Args:
        img_data: One of
            * an ``http://`` / ``https://`` URL, which is downloaded;
            * a ``data:<mediatype>[;param=value...];base64,<payload>`` URL;
            * a raw base64-encoded image payload.
            Surrounding whitespace is ignored in all three cases.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        Whatever ``httpx`` raises for a failed request or a non-success
        status (``raise_for_status``), ``binascii.Error`` for malformed
        base64, and Pillow's own error when the bytes are not an image.
    """
    source = img_data.strip()

    # Only the short prefix is lower-cased so that a multi-megabyte base64
    # payload is not copied just to test the scheme.
    if source[:8].lower().startswith(("http://", "https://")):
        resp = httpx.get(source, timeout=30.0)
        resp.raise_for_status()
        payload = resp.content
    else:
        # Strip the data-URL header if present.  ``[^,]*`` (rather than
        # ``[^;]+``) accepts an empty media type and any number of
        # ``;name=value`` parameters before the ``;base64,`` marker; without
        # that, the header would leak into the lenient base64 decoder below.
        m = re.match(r"^data:[^,]*;base64,(.+)$", source, re.I | re.S)
        raw = m.group(1) if m else source
        payload = base64.b64decode(raw)

    return Image.open(io.BytesIO(payload)).convert("RGB")
def stitch_images(images: List[str], direction: Literal["horizontal", "vertical", "grid"] = "horizontal",
                  columns: int = 2, spacing: int = 0, background_color: str = "#FFFFFF",
                  resize_mode: Literal["none", "fit_width", "fit_height"] = "none") -> dict:
    """Combine several images into one PNG laid out in a row, column or grid.

    Args:
        images: At least two image sources; each is passed to
            ``_fetch_image`` (URL, data URL or raw base64).
        direction: ``"horizontal"`` places images left to right,
            ``"vertical"`` top to bottom; any other value (normally
            ``"grid"``) uses the grid layout.
        columns: Number of grid columns. Only used by the grid layout,
            where it must be at least 1.
        spacing: Gap in pixels inserted between neighbouring images.
        background_color: Fill colour of the canvas, visible in gaps and
            wherever an image is smaller than its slot.
        resize_mode: ``"fit_width"`` scales every image to the widest
            image's width, ``"fit_height"`` to the tallest image's height
            (both keep the aspect ratio); ``"none"`` leaves sizes alone.

    Returns:
        A dict with ``"image"`` (base64-encoded PNG bytes, no data-URL
        prefix), ``"width"`` and ``"height"`` of the stitched result.

    Raises:
        ValueError: If fewer than two images are given, or if the grid
            layout is requested with ``columns`` below 1.
    """
    if len(images) < 2:
        raise ValueError("至少需要 2 张图片")

    # Reject an unusable column count before any image is downloaded;
    # previously columns=0 surfaced as ZeroDivisionError and negative values
    # as an invalid canvas size, both only after all the network I/O.
    use_grid = direction not in ("horizontal", "vertical")
    if use_grid and columns < 1:
        raise ValueError("columns 必须大于等于 1")

    pil_images = [_fetch_image(img_data) for img_data in images]

    # Optional normalisation: scale everything up to the largest width/height.
    if resize_mode == "fit_width":
        w = max(img.width for img in pil_images)
        pil_images = [
            img.resize((w, int(img.height * w / img.width)), Image.LANCZOS) if img.width != w else img
            for img in pil_images
        ]
    elif resize_mode == "fit_height":
        h = max(img.height for img in pil_images)
        pil_images = [
            img.resize((int(img.width * h / img.height), h), Image.LANCZOS) if img.height != h else img
            for img in pil_images
        ]

    gaps = len(pil_images) - 1
    if direction == "horizontal":
        w = sum(img.width for img in pil_images) + spacing * gaps
        h = max(img.height for img in pil_images)
        result = Image.new("RGB", (w, h), background_color)
        x = 0
        for img in pil_images:
            result.paste(img, (x, 0))
            x += img.width + spacing
    elif direction == "vertical":
        w = max(img.width for img in pil_images)
        h = sum(img.height for img in pil_images) + spacing * gaps
        result = Image.new("RGB", (w, h), background_color)
        y = 0
        for img in pil_images:
            result.paste(img, (0, y))
            y += img.height + spacing
    else:
        # Grid: every cell is as large as the largest image; each image is
        # pasted at its cell's top-left corner, filling rows left to right.
        rows = math.ceil(len(pil_images) / columns)
        max_w = max(img.width for img in pil_images)
        max_h = max(img.height for img in pil_images)
        w = max_w * columns + spacing * (columns - 1)
        h = max_h * rows + spacing * (rows - 1)
        result = Image.new("RGB", (w, h), background_color)
        for i, img in enumerate(pil_images):
            row, col = divmod(i, columns)
            result.paste(img, (col * (max_w + spacing), row * (max_h + spacing)))

    buf = io.BytesIO()
    result.save(buf, format="PNG")
    return {"image": base64.b64encode(buf.getvalue()).decode(), "width": result.width, "height": result.height}
|