stitch_core.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import base64
  2. import io
  3. import math
  4. import httpx
  5. import re
  6. from typing import List, Literal
  7. from PIL import Image
  8. def _fetch_image(img_data: str) -> Image.Image:
  9. if img_data.startswith("http://") or img_data.startswith("https://"):
  10. resp = httpx.get(img_data, timeout=30.0)
  11. resp.raise_for_status()
  12. return Image.open(io.BytesIO(resp.content)).convert("RGB")
  13. else:
  14. # Strip data URL if present
  15. s = img_data.strip()
  16. m = re.match(r"^data:[^;]+;base64,(.+)$", s, re.I | re.S)
  17. raw = m.group(1) if m else s
  18. return Image.open(io.BytesIO(base64.b64decode(raw))).convert("RGB")
  19. def stitch_images(images: List[str], direction: Literal["horizontal", "vertical", "grid"] = "horizontal",
  20. columns: int = 2, spacing: int = 0, background_color: str = "#FFFFFF",
  21. resize_mode: Literal["none", "fit_width", "fit_height"] = "none") -> dict:
  22. if len(images) < 2:
  23. raise ValueError("至少需要 2 张图片")
  24. pil_images = [_fetch_image(img_data) for img_data in images]
  25. if resize_mode == "fit_width":
  26. w = max(img.width for img in pil_images)
  27. pil_images = [img.resize((w, int(img.height * w / img.width)), Image.LANCZOS) if img.width != w else img for img in pil_images]
  28. elif resize_mode == "fit_height":
  29. h = max(img.height for img in pil_images)
  30. pil_images = [img.resize((int(img.width * h / img.height), h), Image.LANCZOS) if img.height != h else img for img in pil_images]
  31. if direction == "horizontal":
  32. w = sum(img.width for img in pil_images) + spacing * (len(pil_images) - 1)
  33. h = max(img.height for img in pil_images)
  34. result = Image.new("RGB", (w, h), background_color)
  35. x = 0
  36. for img in pil_images:
  37. result.paste(img, (x, 0))
  38. x += img.width + spacing
  39. elif direction == "vertical":
  40. w = max(img.width for img in pil_images)
  41. h = sum(img.height for img in pil_images) + spacing * (len(pil_images) - 1)
  42. result = Image.new("RGB", (w, h), background_color)
  43. y = 0
  44. for img in pil_images:
  45. result.paste(img, (0, y))
  46. y += img.height + spacing
  47. else:
  48. rows = math.ceil(len(pil_images) / columns)
  49. max_w, max_h = max(img.width for img in pil_images), max(img.height for img in pil_images)
  50. w = max_w * columns + spacing * (columns - 1)
  51. h = max_h * rows + spacing * (rows - 1)
  52. result = Image.new("RGB", (w, h), background_color)
  53. for i, img in enumerate(pil_images):
  54. x = (i % columns) * (max_w + spacing)
  55. y = (i // columns) * (max_h + spacing)
  56. result.paste(img, (x, y))
  57. buf = io.BytesIO()
  58. result.save(buf, format="PNG")
  59. return {"image": base64.b64encode(buf.getvalue()).decode(), "width": result.width, "height": result.height}