main.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. """
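FastAPI service for the Kuaishou Kling AI tools.
Provides a unified HTTP interface for AI video generation, AI image generation, AI lip sync, and related features.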
  4. """
  5. import asyncio
  6. from fastapi import FastAPI, HTTPException
  7. from pydantic import BaseModel, Field
  8. from typing import Optional, Dict, Any, List
  9. from enum import Enum
  10. import uvicorn
  11. from kling_client import KlingClient, BizType, TaskStatus
  12. app = FastAPI(
  13. title="快手可灵AI工具",
  14. description="支持AI视频生成、AI图片生成、AI对口型等功能",
  15. version="1.0.0"
  16. )
  17. class GenerateRequest(BaseModel):
  18. """生成请求模型"""
  19. biz_type: str = Field(..., description="业务类型: aiImage, aiVideo, aiLipSync")
  20. action: Optional[str] = Field(None, description="动作类型")
  21. prompt: Optional[str] = Field(None, description="生成内容的提示词")
  22. negative_prompt: Optional[str] = Field(None, description="不希望呈现的内容")
  23. cfg: str = Field("50", description="创意想象力与创意相关性比例")
  24. mode: Optional[str] = Field(None, description="生成模式: text2video, audio2video")
  25. image_url: Optional[str] = Field(None, description="参考图片地址")
  26. aspect_ratio: str = Field("16:9", description="长宽比: 9:16, 16:9, 1:1")
  27. task_id: Optional[str] = Field(None, description="查询任务状态时使用")
  28. cookie: Optional[str] = Field(None, description="认证Cookie")
  29. version: Optional[str] = Field(None, description="模型版本")
  30. image_count: int = Field(4, description="生成图片数量(1-4)")
  31. add_audio: bool = Field(False, description="是否自动添加音频")
  32. start_frame_image: Optional[str] = Field(None, description="首帧图片URL")
  33. end_frame_image: Optional[str] = Field(None, description="尾帧图片URL")
  34. video_id: Optional[str] = Field(None, description="视频ID(对口型用)")
  35. video_url: Optional[str] = Field(None, description="视频URL(对口型用)")
  36. text: Optional[str] = Field(None, description="对口型文本内容")
  37. voice_id: Optional[str] = Field(None, description="音色ID")
  38. voice_language: str = Field("zh", description="音色语种: zh, en")
  39. voice_speed: float = Field(1.0, description="语速")
  40. audio_type: Optional[str] = Field(None, description="音频类型: file, url")
  41. audio_file: Optional[str] = Field(None, description="音频文件路径")
  42. audio_url: Optional[str] = Field(None, description="音频URL")
  43. class GenerateResponse(BaseModel):
  44. """生成响应模型"""
  45. task_id: Optional[str] = Field(None, description="任务ID")
  46. status: Optional[str] = Field(None, description="任务状态: process, finished, failed")
  47. result: Optional[Dict[str, Any]] = Field(None, description="生成结果")
  48. error: Optional[str] = Field(None, description="错误信息")
  49. class StatusResponse(BaseModel):
  50. """状态查询响应模型"""
  51. task_id: str = Field(..., description="任务ID")
  52. status: str = Field(..., description="任务状态: process, finished, failed")
  53. result: Optional[Dict[str, Any]] = Field(None, description="生成结果")
  54. error: Optional[str] = Field(None, description="错误信息")
  55. @app.get("/")
  56. async def root():
  57. """根路径"""
  58. return {
  59. "service": "快手可灵AI工具",
  60. "version": "1.0.0",
  61. "endpoints": {
  62. "generate": "POST /generate - 创建生成任务",
  63. "status": "GET /status/{task_id} - 查询任务状态"
  64. }
  65. }
  66. @app.post("/generate", response_model=GenerateResponse)
  67. async def generate(request: GenerateRequest):
  68. """
  69. 创建生成任务
  70. 支持三种业务类型:
  71. - aiImage: AI图片生成
  72. - aiVideo: AI视频生成
  73. - aiLipSync: AI对口型
  74. """
  75. try:
  76. client = KlingClient(cookie=request.cookie)
  77. # 根据业务类型调用不同的API
  78. if request.biz_type == "aiImage":
  79. if not request.prompt:
  80. raise HTTPException(status_code=400, detail="prompt is required for aiImage")
  81. result = await client.create_image_task(
  82. prompt=request.prompt,
  83. negative_prompt=request.negative_prompt,
  84. cfg=request.cfg,
  85. aspect_ratio=request.aspect_ratio,
  86. image_count=request.image_count,
  87. version=request.version
  88. )
  89. elif request.biz_type == "aiVideo":
  90. if not request.prompt:
  91. raise HTTPException(status_code=400, detail="prompt is required for aiVideo")
  92. result = await client.create_video_task(
  93. prompt=request.prompt,
  94. mode=request.mode or "text2video",
  95. image_url=request.image_url,
  96. aspect_ratio=request.aspect_ratio,
  97. add_audio=request.add_audio,
  98. start_frame_image=request.start_frame_image,
  99. end_frame_image=request.end_frame_image,
  100. version=request.version
  101. )
  102. elif request.biz_type == "aiLipSync":
  103. result = await client.create_lipsync_task(
  104. video_id=request.video_id,
  105. video_url=request.video_url,
  106. mode=request.mode or "text2video",
  107. text=request.text,
  108. voice_id=request.voice_id,
  109. voice_language=request.voice_language,
  110. voice_speed=request.voice_speed,
  111. audio_type=request.audio_type,
  112. audio_file=request.audio_file,
  113. audio_url=request.audio_url
  114. )
  115. else:
  116. raise HTTPException(
  117. status_code=400,
  118. detail=f"Invalid biz_type: {request.biz_type}. Must be one of: aiImage, aiVideo, aiLipSync"
  119. )
  120. # 解析响应
  121. task_id = result.get("task_id") or result.get("data", {}).get("task_id")
  122. status = result.get("status", "process")
  123. return GenerateResponse(
  124. task_id=task_id,
  125. status=status,
  126. result=result.get("result") or result.get("data"),
  127. error=result.get("error")
  128. )
  129. except HTTPException:
  130. raise
  131. except Exception as e:
  132. raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
  133. @app.get("/status/{task_id}", response_model=StatusResponse)
  134. async def get_status(
  135. task_id: str,
  136. biz_type: str = "aiImage",
  137. cookie: Optional[str] = None
  138. ):
  139. """
  140. 查询任务状态
  141. Args:
  142. task_id: 任务ID
  143. biz_type: 业务类型 (aiImage, aiVideo, aiLipSync)
  144. cookie: 认证Cookie
  145. """
  146. try:
  147. # 验证biz_type
  148. if biz_type not in ["aiImage", "aiVideo", "aiLipSync"]:
  149. raise HTTPException(
  150. status_code=400,
  151. detail=f"Invalid biz_type: {biz_type}. Must be one of: aiImage, aiVideo, aiLipSync"
  152. )
  153. client = KlingClient(cookie=cookie)
  154. biz_type_enum = BizType(biz_type)
  155. result = await client.query_task_status(task_id, biz_type_enum)
  156. return StatusResponse(
  157. task_id=task_id,
  158. status=result.get("status", "process"),
  159. result=result.get("result") or result.get("data"),
  160. error=result.get("error")
  161. )
  162. except HTTPException:
  163. raise
  164. except Exception as e:
  165. raise HTTPException(status_code=500, detail=f"Internal error: {str(e)}")
  166. @app.get("/health")
  167. async def health():
  168. """健康检查"""
  169. return {"status": "healthy"}
  170. if __name__ == "__main__":
  171. import sys
  172. port = int(sys.argv[1]) if len(sys.argv) > 1 else 8000
  173. uvicorn.run(app, host="0.0.0.0", port=port)