main.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. """即梦AI工具 - FastAPI接口"""
  2. import os
  3. from typing import Optional, Literal
  4. from contextlib import asynccontextmanager
  5. from fastapi import FastAPI, HTTPException
  6. from pydantic import BaseModel, Field
  7. from dotenv import load_dotenv
  8. from jimeng_client import JimengClient
  9. # 加载环境变量
  10. load_dotenv()
  11. # 全局客户端实例
  12. client: Optional[JimengClient] = None
  13. @asynccontextmanager
  14. async def lifespan(app: FastAPI):
  15. """应用生命周期管理"""
  16. global client
  17. api_key = os.getenv("JIMENG_API_KEY")
  18. cookie = os.getenv("JIMENG_COOKIE")
  19. base_url = os.getenv("JIMENG_BASE_URL", "https://api.jimeng.ai")
  20. client = JimengClient(api_key=api_key, cookie=cookie, base_url=base_url)
  21. yield
  22. if client:
  23. await client.close()
  24. app = FastAPI(
  25. title="即梦AI工具",
  26. description="支持文生图(Seendance 2.0)和图生视频(Seedream Lite 5.0)",
  27. version="1.0.0",
  28. lifespan=lifespan
  29. )
  30. class GenerateRequest(BaseModel):
  31. """生成请求模型"""
  32. action: Literal["text2image", "image2video", "query_status"] = Field(
  33. ...,
  34. description="操作类型"
  35. )
  36. # 通用参数
  37. prompt: Optional[str] = Field(None, description="正向提示词")
  38. negative_prompt: Optional[str] = Field("", description="负向提示词")
  39. seed: Optional[int] = Field(-1, description="随机种子")
  40. # 文生图参数
  41. model: Optional[str] = Field("seendance_2.0", description="模型选择")
  42. aspect_ratio: Optional[str] = Field("1:1", description="图片长宽比")
  43. image_count: Optional[int] = Field(1, ge=1, le=4, description="生成图片数量")
  44. cfg_scale: Optional[float] = Field(7.0, ge=1.0, le=20.0, description="创意强度")
  45. steps: Optional[int] = Field(20, ge=10, le=50, description="生成步数")
  46. # 图生视频参数
  47. image_url: Optional[str] = Field(None, description="参考图片URL")
  48. image_base64: Optional[str] = Field(None, description="参考图片Base64")
  49. video_duration: Optional[int] = Field(5, description="视频时长(秒)")
  50. motion_strength: Optional[float] = Field(0.5, ge=0.0, le=1.0, description="运动强度")
  51. start_frame: Optional[str] = Field(None, description="首帧图片")
  52. end_frame: Optional[str] = Field(None, description="尾帧图片")
  53. # 查询参数
  54. task_id: Optional[str] = Field(None, description="任务ID")
  55. # 认证参数
  56. cookie: Optional[str] = Field(None, description="认证Cookie")
  57. api_key: Optional[str] = Field(None, description="API密钥")
  58. class GenerateResponse(BaseModel):
  59. """生成响应模型"""
  60. task_id: str
  61. status: Literal["pending", "processing", "completed", "failed"]
  62. progress: Optional[float] = None
  63. result: Optional[dict] = None
  64. error: Optional[str] = None
  65. estimated_time: Optional[int] = None
  66. @app.get("/health")
  67. async def health_check():
  68. """健康检查"""
  69. return {
  70. "status": "healthy",
  71. "service": "jimeng_ai",
  72. "version": "1.0.0"
  73. }
  74. @app.post("/generate", response_model=GenerateResponse)
  75. async def generate(request: GenerateRequest):
  76. """创建生成任务"""
  77. if not client:
  78. raise HTTPException(status_code=500, detail="Client not initialized")
  79. # 使用请求中的认证信息(如果提供)
  80. active_client = client
  81. if request.api_key or request.cookie:
  82. active_client = JimengClient(
  83. api_key=request.api_key or client.api_key,
  84. cookie=request.cookie or client.cookie,
  85. base_url=client.base_url
  86. )
  87. try:
  88. if request.action == "text2image":
  89. if not request.prompt:
  90. raise HTTPException(status_code=400, detail="prompt is required for text2image")
  91. result = await active_client.text2image(
  92. prompt=request.prompt,
  93. negative_prompt=request.negative_prompt or "",
  94. aspect_ratio=request.aspect_ratio or "1:1",
  95. image_count=request.image_count or 1,
  96. cfg_scale=request.cfg_scale or 7.0,
  97. steps=request.steps or 20,
  98. seed=request.seed or -1
  99. )
  100. elif request.action == "image2video":
  101. if not request.image_url and not request.image_base64:
  102. raise HTTPException(
  103. status_code=400,
  104. detail="Either image_url or image_base64 is required for image2video"
  105. )
  106. result = await active_client.image2video(
  107. image_url=request.image_url,
  108. image_base64=request.image_base64,
  109. prompt=request.prompt or "",
  110. video_duration=request.video_duration or 5,
  111. motion_strength=request.motion_strength or 0.5,
  112. start_frame=request.start_frame,
  113. end_frame=request.end_frame,
  114. seed=request.seed or -1
  115. )
  116. elif request.action == "query_status":
  117. if not request.task_id:
  118. raise HTTPException(status_code=400, detail="task_id is required for query_status")
  119. result = await active_client.query_status(request.task_id)
  120. else:
  121. raise HTTPException(status_code=400, detail=f"Unknown action: {request.action}")
  122. return GenerateResponse(**result)
  123. except Exception as e:
  124. raise HTTPException(status_code=500, detail=str(e))
  125. finally:
  126. if active_client != client:
  127. await active_client.close()
  128. @app.get("/status/{task_id}", response_model=GenerateResponse)
  129. async def get_status(task_id: str):
  130. """查询任务状态"""
  131. if not client:
  132. raise HTTPException(status_code=500, detail="Client not initialized")
  133. try:
  134. result = await client.query_status(task_id)
  135. return GenerateResponse(**result)
  136. except Exception as e:
  137. raise HTTPException(status_code=500, detail=str(e))
  138. @app.post("/cleanup")
  139. async def cleanup_cache():
  140. """清理过期缓存"""
  141. if not client:
  142. raise HTTPException(status_code=500, detail="Client not initialized")
  143. client.cleanup_cache()
  144. return {"status": "success", "message": "Cache cleaned up"}
  145. if __name__ == "__main__":
  146. import uvicorn
  147. import argparse
  148. parser = argparse.ArgumentParser(description="即梦AI工具服务")
  149. parser.add_argument("--port", type=int, default=8000, help="服务端口")
  150. parser.add_argument("--host", type=str, default="0.0.0.0", help="服务地址")
  151. args = parser.parse_args()
  152. uvicorn.run(app, host=args.host, port=args.port)