stitch_core.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import base64
  2. import io
  3. import math
  4. from typing import List, Literal
  5. from PIL import Image
  6. def stitch_images(images: List[str], direction: Literal["horizontal", "vertical", "grid"] = "horizontal",
  7. columns: int = 2, spacing: int = 0, background_color: str = "#FFFFFF",
  8. resize_mode: Literal["none", "fit_width", "fit_height"] = "none") -> dict:
  9. if len(images) < 2:
  10. raise ValueError("至少需要 2 张图片")
  11. pil_images = [Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB") for img_b64 in images]
  12. if resize_mode == "fit_width":
  13. w = max(img.width for img in pil_images)
  14. pil_images = [img.resize((w, int(img.height * w / img.width)), Image.LANCZOS) if img.width != w else img for img in pil_images]
  15. elif resize_mode == "fit_height":
  16. h = max(img.height for img in pil_images)
  17. pil_images = [img.resize((int(img.width * h / img.height), h), Image.LANCZOS) if img.height != h else img for img in pil_images]
  18. if direction == "horizontal":
  19. w = sum(img.width for img in pil_images) + spacing * (len(pil_images) - 1)
  20. h = max(img.height for img in pil_images)
  21. result = Image.new("RGB", (w, h), background_color)
  22. x = 0
  23. for img in pil_images:
  24. result.paste(img, (x, 0))
  25. x += img.width + spacing
  26. elif direction == "vertical":
  27. w = max(img.width for img in pil_images)
  28. h = sum(img.height for img in pil_images) + spacing * (len(pil_images) - 1)
  29. result = Image.new("RGB", (w, h), background_color)
  30. y = 0
  31. for img in pil_images:
  32. result.paste(img, (0, y))
  33. y += img.height + spacing
  34. else:
  35. rows = math.ceil(len(pil_images) / columns)
  36. max_w, max_h = max(img.width for img in pil_images), max(img.height for img in pil_images)
  37. w = max_w * columns + spacing * (columns - 1)
  38. h = max_h * rows + spacing * (rows - 1)
  39. result = Image.new("RGB", (w, h), background_color)
  40. for i, img in enumerate(pil_images):
  41. x = (i % columns) * (max_w + spacing)
  42. y = (i // columns) * (max_h + spacing)
  43. result.paste(img, (x, y))
  44. buf = io.BytesIO()
  45. result.save(buf, format="PNG")
  46. return {"image": base64.b64encode(buf.getvalue()).decode(), "width": result.width, "height": result.height}