main.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from fastapi import FastAPI, HTTPException
  2. from pydantic import BaseModel, Field
  3. from typing import List, Literal
  4. from stitch_core import stitch_images
  # FastAPI application object; route handlers in this module register on it
  # through the @app.get / @app.post decorators. The title appears in the
  # generated OpenAPI docs.
  app = FastAPI(
      title="Image Stitcher API",
  )
  class StitchRequest(BaseModel):
      # Request body for POST /stitch. pydantic enforces the Field constraints
      # below when the body is parsed, before the handler sees it.
      # (Comments rather than a docstring: a pydantic v2 class docstring would
      # become the schema description in the OpenAPI output.)

      # Images as base64 strings; at least two are required.
      images: List[str] = Field(
          default=...,
          min_length=2,
          description="Base64 编码的图片列表(至少 2 张)",
      )
      # Stitching layout.
      direction: Literal["horizontal", "vertical", "grid"] = "horizontal"
      # Column count, must be >= 1; presumably only used when direction is
      # "grid" -- confirm in stitch_core.
      columns: int = Field(default=2, ge=1)
      # Gap between images, must be >= 0 (units presumably pixels -- confirm).
      spacing: int = Field(default=0, ge=0)
      # Canvas fill colour; not validated here, interpretation is up to
      # stitch_core.
      background_color: str = "#FFFFFF"
      # Resizing strategy applied before stitching (semantics defined in
      # stitch_core).
      resize_mode: Literal["none", "fit_width", "fit_height"] = "none"
  class StitchResponse(BaseModel):
      # Response body of POST /stitch (declared as its response_model, so the
      # handler's return value is validated/serialized against these fields).

      # Stitched output image; presumably base64-encoded like the inputs --
      # confirm in stitch_core.
      image: str
      # Dimensions of the stitched image (presumably pixels).
      width: int
      height: int
  @app.get("/health")
  def health():
      # Lightweight health check: returns a constant payload and does not call
      # into stitch_core. (No docstring/return annotation on purpose: FastAPI
      # would feed either into the OpenAPI output.)
      status_payload = {"status": "ok"}
      return status_payload
  @app.post("/stitch", response_model=StitchResponse)
  def stitch(req: StitchRequest):
      # Stitch the request's images with stitch_core.stitch_images and return
      # its result; FastAPI validates/serializes that against StitchResponse.
      #
      # Error mapping:
      #   ValueError raised by stitch_images -> HTTP 400, message as detail
      #   any other exception                -> HTTP 500, traceback logged
      #                                         server-side, message as detail
      # Both re-raises chain the original exception (`from e`) so the cause
      # is preserved.
      import logging  # local import, matching the file's __main__ block style

      try:
          result = stitch_images(
              images=req.images,
              direction=req.direction,
              columns=req.columns,
              spacing=req.spacing,
              background_color=req.background_color,
              resize_mode=req.resize_mode,
          )
      except ValueError as e:
          raise HTTPException(status_code=400, detail=str(e)) from e
      except Exception as e:
          # Without this, unexpected failures would leave no traceback in the
          # server logs -- the client only ever sees the message text.
          logging.getLogger(__name__).exception("Image stitching failed")
          raise HTTPException(status_code=500, detail=f"拼接失败: {e}") from e
      return result
  if __name__ == "__main__":
      # Script entry point: parse CLI options and serve `app` with uvicorn.
      # argparse/uvicorn are imported here so they are only needed when the
      # module is run as a script.
      import argparse

      import uvicorn

      parser = argparse.ArgumentParser(description="Run the Image Stitcher API server.")
      # Previously hard-coded; the default keeps the old behaviour of binding
      # all interfaces.
      parser.add_argument(
          "--host",
          default="0.0.0.0",
          help="interface to bind (default: %(default)s)",
      )
      parser.add_argument(
          "--port",
          type=int,
          default=8001,
          help="port to listen on (default: %(default)s)",
      )
      args = parser.parse_args()
      uvicorn.run(app, host=args.host, port=args.port)