"""Jimeng AI tool - FastAPI service.

Exposes ``POST /generate``, which dispatches text-to-image, image-to-video and
task-status requests to a ``JimengClient``, plus ``GET /status/{task_id}``,
``GET /health`` and ``POST /cleanup`` helpers.

The shared client is configured from the environment (a ``.env`` file is
loaded first if present): ``JIMENG_API_KEY``, ``JIMENG_COOKIE`` and
``JIMENG_BASE_URL`` (default ``https://api.jimeng.ai``).
"""

import os
from contextlib import asynccontextmanager
from typing import Literal, Optional, TypeVar

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

from jimeng_client import JimengClient

# Load variables from a .env file (if any) before `lifespan` reads them.
load_dotenv()

# Process-wide shared client; created on startup and closed on shutdown by
# `lifespan`. `None` means the application has not started (or has stopped).
client: Optional[JimengClient] = None

_T = TypeVar("_T")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared client on startup and close it on shutdown.

    The close runs in ``finally`` so it also happens when shutdown is
    entered via an exception, and the global is reset so no request can
    pick up a closed client afterwards.
    """
    global client
    client = JimengClient(
        api_key=os.getenv("JIMENG_API_KEY"),
        cookie=os.getenv("JIMENG_COOKIE"),
        base_url=os.getenv("JIMENG_BASE_URL", "https://api.jimeng.ai"),
    )
    try:
        yield
    finally:
        if client is not None:
            await client.close()
            client = None


# NOTE(review): the description pairs "Seendance 2.0" with text-to-image and
# "Seedream Lite 5.0" with image-to-video; confirm the model names/mapping
# against the upstream API (left unchanged here because it is user-facing).
app = FastAPI(
    title="即梦AI工具",
    description="支持文生图(Seendance 2.0)和图生视频(Seedream Lite 5.0)",
    version="1.0.0",
    lifespan=lifespan,
)


class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``.

    ``action`` selects the operation; only the parameters relevant to that
    action are read by the endpoint.
    """

    action: Literal["text2image", "image2video", "query_status"] = Field(
        ..., description="操作类型"
    )

    # Parameters shared by several actions.
    prompt: Optional[str] = Field(None, description="正向提示词")
    negative_prompt: Optional[str] = Field("", description="负向提示词")
    seed: Optional[int] = Field(-1, description="随机种子")

    # text2image parameters.
    # NOTE(review): `model` is accepted but never forwarded to the client by
    # `generate`; confirm whether JimengClient.text2image should receive it.
    model: Optional[str] = Field("seendance_2.0", description="模型选择")
    aspect_ratio: Optional[str] = Field("1:1", description="图片长宽比")
    image_count: Optional[int] = Field(1, ge=1, le=4, description="生成图片数量")
    cfg_scale: Optional[float] = Field(7.0, ge=1.0, le=20.0, description="创意强度")
    steps: Optional[int] = Field(20, ge=10, le=50, description="生成步数")

    # image2video parameters (one of image_url / image_base64 is required).
    image_url: Optional[str] = Field(None, description="参考图片URL")
    image_base64: Optional[str] = Field(None, description="参考图片Base64")
    video_duration: Optional[int] = Field(5, description="视频时长(秒)")
    motion_strength: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="运动强度")
    start_frame: Optional[str] = Field(None, description="首帧图片")
    end_frame: Optional[str] = Field(None, description="尾帧图片")

    # query_status parameters.
    task_id: Optional[str] = Field(None, description="任务ID")

    # Per-request credentials: when either is set, `generate` builds a
    # temporary client for this request instead of using the shared one.
    cookie: Optional[str] = Field(None, description="认证Cookie")
    api_key: Optional[str] = Field(None, description="API密钥")


class GenerateResponse(BaseModel):
    """Task state returned by ``/generate`` and ``/status/{task_id}``.

    Built directly from the dict returned by the client methods, so those
    dicts must provide at least ``task_id`` and ``status``.
    """

    task_id: str
    status: Literal["pending", "processing", "completed", "failed"]
    progress: Optional[float] = None
    result: Optional[dict] = None
    error: Optional[str] = None
    estimated_time: Optional[int] = None


def _or_default(value: Optional[_T], default: _T) -> _T:
    """Return *value* unless it is ``None``, in which case return *default*.

    Unlike ``value or default`` this keeps legitimate falsy values such as a
    seed of ``0`` or a motion strength of ``0.0``.
    """
    return default if value is None else value


def _require_client() -> JimengClient:
    """Return the shared client, or raise HTTP 500 if it is not initialized."""
    if client is None:
        raise HTTPException(status_code=500, detail="Client not initialized")
    return client


async def _dispatch(active_client: JimengClient, request: GenerateRequest) -> dict:
    """Validate action-specific arguments and call the matching client method.

    Returns the raw dict produced by the client.

    Raises:
        HTTPException: 400 when a required argument for the action is missing
            or the action is unknown.
    """
    if request.action == "text2image":
        if not request.prompt:
            raise HTTPException(status_code=400, detail="prompt is required for text2image")
        return await active_client.text2image(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt or "",
            aspect_ratio=request.aspect_ratio or "1:1",
            image_count=_or_default(request.image_count, 1),
            cfg_scale=_or_default(request.cfg_scale, 7.0),
            steps=_or_default(request.steps, 20),
            seed=_or_default(request.seed, -1),
        )

    if request.action == "image2video":
        if not request.image_url and not request.image_base64:
            raise HTTPException(
                status_code=400,
                detail="Either image_url or image_base64 is required for image2video",
            )
        return await active_client.image2video(
            image_url=request.image_url,
            image_base64=request.image_base64,
            prompt=request.prompt or "",
            # A duration of 0 is treated like "unset", as before.
            video_duration=request.video_duration or 5,
            motion_strength=_or_default(request.motion_strength, 0.5),
            start_frame=request.start_frame,
            end_frame=request.end_frame,
            seed=_or_default(request.seed, -1),
        )

    if request.action == "query_status":
        if not request.task_id:
            raise HTTPException(status_code=400, detail="task_id is required for query_status")
        return await active_client.query_status(request.task_id)

    # Unreachable for validated requests (the Literal type restricts
    # `action`), kept as a defensive guard.
    raise HTTPException(status_code=400, detail=f"Unknown action: {request.action}")


@app.get("/health")
async def health_check():
    """Liveness probe: report service name and version."""
    return {
        "status": "healthy",
        "service": "jimeng_ai",
        "version": "1.0.0",
    }


@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Create a generation task, or query one, depending on ``request.action``.

    Raises:
        HTTPException: 400 when the action's required arguments are missing;
            500 when the client is not initialized or the upstream call fails.
    """
    shared = _require_client()

    # Per-request credentials override the shared ones. Such a client is
    # private to this request and must be closed when we are done with it.
    owns_client = bool(request.api_key or request.cookie)
    if owns_client:
        active_client = JimengClient(
            api_key=request.api_key or shared.api_key,
            cookie=request.cookie or shared.cookie,
            base_url=shared.base_url,
        )
    else:
        active_client = shared

    try:
        result = await _dispatch(active_client, request)
        return GenerateResponse(**result)
    except HTTPException:
        # Preserve deliberate 4xx responses instead of re-wrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        if owns_client:
            await active_client.close()


@app.get("/status/{task_id}", response_model=GenerateResponse)
async def get_status(task_id: str):
    """Return the current state of task *task_id* via the shared client.

    Raises:
        HTTPException: 500 when the client is not initialized or the query fails.
    """
    shared = _require_client()
    try:
        result = await shared.query_status(task_id)
        return GenerateResponse(**result)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/cleanup")
async def cleanup_cache():
    """Ask the shared client to clear its expired cache entries."""
    shared = _require_client()
    # NOTE(review): called without `await`; if JimengClient.cleanup_cache is a
    # coroutine function this never runs - verify against jimeng_client.
    shared.cleanup_cache()
    return {"status": "success", "message": "Cache cleaned up"}


if __name__ == "__main__":
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser(description="即梦AI工具服务")
    parser.add_argument("--port", type=int, default=8000, help="服务端口")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="服务地址")
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port)