| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
- """即梦AI工具 - FastAPI接口"""
- import os
- from typing import Optional, Literal
- from contextlib import asynccontextmanager
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel, Field
- from dotenv import load_dotenv
- from jimeng_client import JimengClient
# Load environment variables (via python-dotenv) at import time.
load_dotenv()

# Module-level client shared by all request handlers. It stays None until
# `lifespan` assigns a JimengClient during application startup.
client: Optional[JimengClient] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the shared Jimeng client for the lifetime of the application.

    On startup, build the module-level ``client`` from the ``JIMENG_API_KEY``,
    ``JIMENG_COOKIE`` and ``JIMENG_BASE_URL`` environment variables (the base
    URL falls back to ``https://api.jimeng.ai``). On shutdown, close the
    client and reset the global to ``None``.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).

    Yields:
        None. Control returns to the framework while the application serves.
    """
    global client

    client = JimengClient(
        api_key=os.getenv("JIMENG_API_KEY"),
        cookie=os.getenv("JIMENG_COOKIE"),
        base_url=os.getenv("JIMENG_BASE_URL", "https://api.jimeng.ai"),
    )
    try:
        yield
    finally:
        # Cleanup must also run when an exception is thrown into the generator
        # at the yield point; otherwise the client would never be closed.
        # Detach the global first so handlers never see a closed client.
        shared, client = client, None
        if shared:
            await shared.close()
# Application instance; startup and shutdown handling is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="即梦AI工具",
    description="支持文生图(Seendance 2.0)和图生视频(Seedream Lite 5.0)",
)
class GenerateRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint.

    ``action`` is the only required field and selects the operation. All
    other fields are optional and carry the defaults declared below.
    """

    action: Literal["text2image", "image2video", "query_status"] = Field(
        ...,
        description="操作类型",
    )

    # --- Parameters shared by several actions ---
    prompt: Optional[str] = Field(default=None, description="正向提示词")
    negative_prompt: Optional[str] = Field(default="", description="负向提示词")
    seed: Optional[int] = Field(default=-1, description="随机种子")

    # --- Text-to-image parameters ---
    model: Optional[str] = Field(default="seendance_2.0", description="模型选择")
    aspect_ratio: Optional[str] = Field(default="1:1", description="图片长宽比")
    image_count: Optional[int] = Field(
        default=1, ge=1, le=4, description="生成图片数量"
    )
    cfg_scale: Optional[float] = Field(
        default=7.0, ge=1.0, le=20.0, description="创意强度"
    )
    steps: Optional[int] = Field(default=20, ge=10, le=50, description="生成步数")

    # --- Image-to-video parameters ---
    image_url: Optional[str] = Field(default=None, description="参考图片URL")
    image_base64: Optional[str] = Field(default=None, description="参考图片Base64")
    video_duration: Optional[int] = Field(default=5, description="视频时长(秒)")
    motion_strength: Optional[float] = Field(
        default=0.5, ge=0.0, le=1.0, description="运动强度"
    )
    start_frame: Optional[str] = Field(default=None, description="首帧图片")
    end_frame: Optional[str] = Field(default=None, description="尾帧图片")

    # --- Status-query parameters ---
    task_id: Optional[str] = Field(default=None, description="任务ID")

    # --- Per-request credentials ---
    cookie: Optional[str] = Field(default=None, description="认证Cookie")
    api_key: Optional[str] = Field(default=None, description="API密钥")
class GenerateResponse(BaseModel):
    """Response body describing a generation task.

    Only ``task_id`` and ``status`` are required; every other field
    defaults to ``None``.
    """

    task_id: str
    status: Literal["pending", "processing", "completed", "failed"]
    progress: Optional[float] = Field(default=None)
    result: Optional[dict] = Field(default=None)
    error: Optional[str] = Field(default=None)
    estimated_time: Optional[int] = Field(default=None)
@app.get("/health")
async def health_check():
    """Return a static payload identifying the service and its version."""
    payload = dict(status="healthy", service="jimeng_ai", version="1.0.0")
    return payload
def _value_or(value, fallback):
    """Return *fallback* only when *value* is None, keeping falsy values like 0."""
    return fallback if value is None else value


@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Dispatch a generation request to the Jimeng client.

    The operation is selected by ``request.action``:

    * ``text2image`` requires ``prompt``.
    * ``image2video`` requires ``image_url`` or ``image_base64``.
    * ``query_status`` requires ``task_id``.

    When the request carries its own ``api_key`` or ``cookie``, a temporary
    client is created for this call. Whichever credential is missing falls
    back to the shared client's value, and the temporary client is closed
    before returning.

    Args:
        request: The validated request body.

    Returns:
        A ``GenerateResponse`` built from the mapping returned by the client.

    Raises:
        HTTPException: 400 when a field required by the chosen action is
            missing or the action is unknown; 500 when the shared client is
            not initialized, or when the client call or response
            construction fails.
    """
    if not client:
        raise HTTPException(status_code=500, detail="Client not initialized")

    # Per-request credentials, when supplied, take precedence over the shared ones.
    active_client = client
    if request.api_key or request.cookie:
        active_client = JimengClient(
            api_key=request.api_key or client.api_key,
            cookie=request.cookie or client.cookie,
            base_url=client.base_url,
        )

    try:
        if request.action == "text2image":
            if not request.prompt:
                raise HTTPException(
                    status_code=400, detail="prompt is required for text2image"
                )

            # NOTE(review): request.model is accepted by the schema but is not
            # forwarded here. Confirm whether the client supports a model argument.
            result = await active_client.text2image(
                prompt=request.prompt,
                negative_prompt=request.negative_prompt or "",
                aspect_ratio=request.aspect_ratio or "1:1",
                image_count=request.image_count or 1,
                cfg_scale=request.cfg_scale or 7.0,
                steps=request.steps or 20,
                # A seed of 0 is a real value; only None falls back to -1.
                seed=_value_or(request.seed, -1),
            )

        elif request.action == "image2video":
            if not request.image_url and not request.image_base64:
                raise HTTPException(
                    status_code=400,
                    detail="Either image_url or image_base64 is required for image2video",
                )

            result = await active_client.image2video(
                image_url=request.image_url,
                image_base64=request.image_base64,
                prompt=request.prompt or "",
                video_duration=request.video_duration or 5,
                # The schema allows 0.0, so it must not be replaced by the default.
                motion_strength=_value_or(request.motion_strength, 0.5),
                start_frame=request.start_frame,
                end_frame=request.end_frame,
                seed=_value_or(request.seed, -1),
            )

        elif request.action == "query_status":
            if not request.task_id:
                raise HTTPException(
                    status_code=400, detail="task_id is required for query_status"
                )

            result = await active_client.query_status(request.task_id)

        else:
            raise HTTPException(
                status_code=400, detail=f"Unknown action: {request.action}"
            )

        return GenerateResponse(**result)

    except HTTPException:
        # The 400 errors raised above must keep their status code instead of
        # being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Only the temporary per-request client is closed here; the shared
        # one is owned by the application lifespan.
        if active_client is not client:
            await active_client.close()
@app.get("/status/{task_id}", response_model=GenerateResponse)
async def get_status(task_id: str):
    """Look up a task's state through the shared client.

    Args:
        task_id: Identifier taken from the URL path.

    Returns:
        A ``GenerateResponse`` built from the client's ``query_status`` result.

    Raises:
        HTTPException: 500 when the shared client is not initialized, or
            when the lookup or response construction fails.
    """
    if not client:
        raise HTTPException(status_code=500, detail="Client not initialized")

    try:
        status_payload = await client.query_status(task_id)
        response = GenerateResponse(**status_payload)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
    return response
@app.post("/cleanup")
async def cleanup_cache():
    """Trigger the shared client's cache cleanup.

    Returns:
        A small success payload once ``client.cleanup_cache()`` has returned.

    Raises:
        HTTPException: 500 when the shared client is not initialized.
    """
    if client:
        client.cleanup_cache()
        return {"status": "success", "message": "Cache cleaned up"}
    raise HTTPException(status_code=500, detail="Client not initialized")
if __name__ == "__main__":
    # Script entry point: parse the bind address from the command line and
    # serve the application with uvicorn.
    import argparse
    import uvicorn

    cli = argparse.ArgumentParser(description="即梦AI工具服务")
    cli.add_argument("--port", type=int, default=8000, help="服务端口")
    cli.add_argument("--host", type=str, default="0.0.0.0", help="服务地址")
    options = cli.parse_args()

    uvicorn.run(app, port=options.port, host=options.host)
|