| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677 |
import logging
from typing import List, Literal

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from stitch_core import stitch_images
# The single ASGI application object; every route below is registered on it.
# The title is what FastAPI uses as the name of the generated API docs.
app = FastAPI(
    title="Image Stitcher API",
)
class StitchRequest(BaseModel):
    """Request body for ``POST /stitch``.

    Pydantic enforces the constraints declared here before the handler runs.
    How each option is interpreted is up to ``stitch_core.stitch_images``.
    """

    # Base64-encoded source images; pydantic rejects fewer than two.
    images: List[str] = Field(
        default=...,
        min_length=2,
        description="Base64 编码的图片列表(至少 2 张)",
    )
    # Layout of the output: one row, one column, or a grid.
    direction: Literal["horizontal", "vertical", "grid"] = "horizontal"
    # Column count, must be >= 1; presumably only used when direction == "grid"
    # (confirm in stitch_core).
    columns: int = Field(default=2, ge=1)
    # Gap between neighbouring images, must be >= 0 (presumably pixels).
    spacing: int = Field(default=0, ge=0)
    # Fill colour passed through to the stitcher; accepted formats are decided
    # by stitch_core (the default is a "#RRGGBB" hex string).
    background_color: str = "#FFFFFF"
    # How inputs of differing size are normalised before stitching.
    resize_mode: Literal["none", "fit_width", "fit_height"] = "none"
import sys
import os

# Make the sibling ``liblibai_controlnet`` directory importable so this service
# can reuse its OSS upload infrastructure (``LibLibAIClient``).
liblib_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../liblibai_controlnet"))
if liblib_dir not in sys.path:
    # NOTE(review): inserting at the front lets modules in that directory shadow
    # same-named installed packages; kept as-is for compatibility.
    sys.path.insert(0, liblib_dir)

try:
    from liblibai_client import LibLibAIClient
except ImportError as exc:
    # The upload client is optional at import time: the app still starts, and
    # /stitch fails per request when it finds LibLibAIClient is None. Log the
    # real cause so a broken transitive dependency of liblibai_client is not
    # mistaken for the module simply being absent.
    logging.getLogger(__name__).warning(
        "liblibai_client could not be imported (search dir: %s): %s",
        liblib_dir,
        exc,
    )
    LibLibAIClient = None
class StitchResponse(BaseModel):
    """Response body for ``POST /stitch``."""

    # Where the uploaded stitched image can be fetched from.
    url: str = Field(default=..., description="The stitched image URL")
    # Output size as reported by stitch_images (presumably pixels — confirm).
    width: int
    height: int
@app.get("/health")
def health():
    """Return a constant status payload; no dependency checks are performed."""
    payload = dict(status="ok")
    return payload
@app.post("/stitch", response_model=StitchResponse)
def stitch(req: StitchRequest):
    """Stitch the request's images and upload the result to object storage.

    The images and layout options are forwarded to
    ``stitch_core.stitch_images``. The base64 image it returns is uploaded with
    ``LibLibAIClient.upload_base64_image``, and the resulting URL is returned
    together with the width and height reported by the stitcher.

    Args:
        req: The validated request body.

    Returns:
        A dict matching ``StitchResponse`` (``url``, ``width``, ``height``).

    Raises:
        HTTPException: 400 if ``stitch_images`` raises ``ValueError`` (bad
            input). 500 for every other failure, including a missing upload
            client and any error raised while uploading.
    """
    logger = logging.getLogger(__name__)
    try:
        # Fail fast: without an uploader the result cannot be delivered, so
        # don't spend time stitching first.
        if LibLibAIClient is None:
            raise RuntimeError("LibLibAIClient not found. Cannot upload image.")

        try:
            result = stitch_images(
                images=req.images,
                direction=req.direction,
                columns=req.columns,
                spacing=req.spacing,
                background_color=req.background_color,
                resize_mode=req.resize_mode,
            )
        except ValueError as e:
            # Only the stitcher's ValueErrors are treated as bad client input.
            # A ValueError raised later (e.g. a json.JSONDecodeError, if the
            # upload client parses a response) is a server-side failure.
            raise HTTPException(status_code=400, detail=str(e)) from e

        # Hand the base64-encoded result to cloud object storage.
        client = LibLibAIClient()
        url = client.upload_base64_image(result["image"])

        return {
            "url": url,
            "width": result["width"],
            "height": result["height"],
        }
    except HTTPException:
        # Already mapped to the right status code above; don't re-wrap as 500.
        raise
    except Exception as e:
        # Keep the full traceback in the server log; clients only get the message.
        # NOTE(review): str(e) may expose internal details to API clients.
        logger.exception("Image stitch request failed")
        raise HTTPException(status_code=500, detail=f"拼接失败: {str(e)}") from e
if __name__ == "__main__":
    # Script entry point: run the app under uvicorn. Imports are local so that
    # importing this module (e.g. from an ASGI server) doesn't need argparse/uvicorn.
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser(description="Run the Image Stitcher API server.")
    # Default keeps the original behaviour of binding every interface; pass
    # --host 127.0.0.1 to keep the service local-only.
    parser.add_argument(
        "--host",
        default="0.0.0.0",
        help="Interface to bind (default: 0.0.0.0, all interfaces).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8001,
        help="Port to listen on (default: 8001).",
    )
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port)
|