test_stitch.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
import sys
import os
import base64
import io

# Work around garbled Chinese (non-ASCII) output on the Windows console.
if sys.platform == "win32":
    sys.stdout.reconfigure(encoding="utf-8")

# Add the project root directory (two directory levels above this file) to
# sys.path. This runs before the stitch_core import below -- presumably so
# that import resolves from the project root; confirm against the repo layout.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from PIL import Image
from stitch_core import stitch_images

# Stitched images are written to an "output" directory next to this file;
# it is created at import time if it does not already exist.
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "output")
os.makedirs(OUTPUT_DIR, exist_ok=True)
  14. def make_color_image(color: str, width: int = 100, height: int = 80) -> str:
  15. img = Image.new("RGB", (width, height), color)
  16. buf = io.BytesIO()
  17. img.save(buf, format="PNG")
  18. return base64.b64encode(buf.getvalue()).decode()
  19. def save_result(result: dict, filename: str):
  20. img_data = base64.b64decode(result["image"])
  21. path = os.path.join(OUTPUT_DIR, filename)
  22. with open(path, "wb") as f:
  23. f.write(img_data)
  24. print(f" saved: {path}")
  25. def test_horizontal():
  26. print("Test: horizontal stitch...")
  27. red = make_color_image("red", 100, 80)
  28. blue = make_color_image("blue", 120, 80)
  29. result = stitch_images([red, blue], direction="horizontal", spacing=0)
  30. assert result["width"] == 220, f"width should be 220, got {result['width']}"
  31. assert result["height"] == 80, f"height should be 80, got {result['height']}"
  32. save_result(result, "horizontal.png")
  33. print(" [PASS] horizontal")
  34. def test_vertical():
  35. print("Test: vertical stitch...")
  36. green = make_color_image("green", 100, 80)
  37. yellow = make_color_image("yellow", 100, 60)
  38. result = stitch_images([green, yellow], direction="vertical", spacing=10)
  39. assert result["width"] == 100, f"width should be 100, got {result['width']}"
  40. assert result["height"] == 150, f"height should be 150, got {result['height']}"
  41. save_result(result, "vertical.png")
  42. print(" [PASS] vertical")
  43. def test_grid():
  44. print("Test: grid stitch...")
  45. imgs = [make_color_image(c, 50, 50) for c in ["red", "green", "blue", "yellow"]]
  46. result = stitch_images(imgs, direction="grid", columns=2, spacing=5)
  47. expected_w = 50 * 2 + 5 * 1 # 105
  48. expected_h = 50 * 2 + 5 * 1 # 105
  49. assert result["width"] == expected_w, f"width should be {expected_w}, got {result['width']}"
  50. assert result["height"] == expected_h, f"height should be {expected_h}, got {result['height']}"
  51. save_result(result, "grid.png")
  52. print(" [PASS] grid")
  53. def test_resize_fit_width():
  54. print("Test: resize_mode=fit_width...")
  55. img1 = make_color_image("cyan", 100, 80)
  56. img2 = make_color_image("magenta", 60, 40)
  57. result = stitch_images([img1, img2], direction="horizontal", resize_mode="fit_width")
  58. # fit_width unifies to max width=100, result width = 100+100 = 200
  59. assert result["width"] == 200, f"width should be 200, got {result['width']}"
  60. save_result(result, "fit_width.png")
  61. print(" [PASS] fit_width")
  62. def test_min_images_error():
  63. print("Test: error on < 2 images...")
  64. try:
  65. img = make_color_image("red")
  66. stitch_images([img], direction="horizontal")
  67. assert False, "Should raise ValueError"
  68. except ValueError as e:
  69. print(f" [PASS] correct error: {e}")
  70. if __name__ == "__main__":
  71. print("=== Image Stitcher Tests ===\n")
  72. test_horizontal()
  73. test_vertical()
  74. test_grid()
  75. test_resize_fit_width()
  76. test_min_images_error()
  77. print("\nAll tests passed!")