jimeng_client.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. """即梦AI客户端 - 支持文生图和图生视频"""
import asyncio
import base64
import time
import uuid
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from urllib.parse import quote

import httpx
from pydantic import BaseModel, Field
  9. class TaskCache:
  10. """任务状态缓存"""
  11. def __init__(self, ttl_hours: int = 24):
  12. self.cache: Dict[str, Dict[str, Any]] = {}
  13. self.ttl = timedelta(hours=ttl_hours)
  14. def set(self, task_id: str, data: Dict[str, Any]):
  15. self.cache[task_id] = {
  16. "data": data,
  17. "timestamp": datetime.now()
  18. }
  19. def get(self, task_id: str) -> Optional[Dict[str, Any]]:
  20. if task_id not in self.cache:
  21. return None
  22. entry = self.cache[task_id]
  23. if datetime.now() - entry["timestamp"] > self.ttl:
  24. del self.cache[task_id]
  25. return None
  26. return entry["data"]
  27. def cleanup(self):
  28. """清理过期缓存"""
  29. now = datetime.now()
  30. expired = [
  31. task_id for task_id, entry in self.cache.items()
  32. if now - entry["timestamp"] > self.ttl
  33. ]
  34. for task_id in expired:
  35. del self.cache[task_id]
  36. class JimengClient:
  37. """即梦AI客户端"""
  38. def __init__(
  39. self,
  40. api_key: Optional[str] = None,
  41. cookie: Optional[str] = None,
  42. base_url: str = "https://api.jimeng.ai"
  43. ):
  44. self.api_key = api_key
  45. self.cookie = cookie
  46. self.base_url = base_url.rstrip("/")
  47. self.cache = TaskCache()
  48. self.client = httpx.AsyncClient(timeout=30.0)
  49. def _get_headers(self) -> Dict[str, str]:
  50. """构建请求头"""
  51. headers = {
  52. "Content-Type": "application/json",
  53. "User-Agent": "JimengAI-Client/1.0"
  54. }
  55. if self.api_key:
  56. headers["Authorization"] = f"Bearer {self.api_key}"
  57. if self.cookie:
  58. headers["Cookie"] = self.cookie
  59. return headers
  60. async def text2image(
  61. self,
  62. prompt: str,
  63. negative_prompt: str = "",
  64. aspect_ratio: str = "1:1",
  65. image_count: int = 1,
  66. cfg_scale: float = 7.0,
  67. steps: int = 20,
  68. seed: int = -1
  69. ) -> Dict[str, Any]:
  70. """文生图 - Seendance 2.0"""
  71. payload = {
  72. "model": "seendance_2.0",
  73. "prompt": prompt,
  74. "negative_prompt": negative_prompt,
  75. "aspect_ratio": aspect_ratio,
  76. "num_images": image_count,
  77. "cfg_scale": cfg_scale,
  78. "steps": steps,
  79. "seed": seed if seed > 0 else None
  80. }
  81. try:
  82. response = await self.client.post(
  83. f"{self.base_url}/v1/text2image",
  84. json=payload,
  85. headers=self._get_headers()
  86. )
  87. response.raise_for_status()
  88. result = response.json()
  89. task_id = result.get("task_id") or result.get("id") or f"task_{int(time.time())}"
  90. task_data = {
  91. "task_id": task_id,
  92. "status": "processing",
  93. "progress": 0,
  94. "type": "text2image",
  95. "created_at": datetime.now().isoformat(),
  96. "estimated_time": steps * image_count * 2
  97. }
  98. self.cache.set(task_id, task_data)
  99. return task_data
  100. except httpx.HTTPStatusError as e:
  101. return {
  102. "task_id": f"error_{int(time.time())}",
  103. "status": "failed",
  104. "error": f"HTTP {e.response.status_code}: {e.response.text}"
  105. }
  106. except Exception as e:
  107. return {
  108. "task_id": f"error_{int(time.time())}",
  109. "status": "failed",
  110. "error": str(e)
  111. }
  112. async def image2video(
  113. self,
  114. image_url: Optional[str] = None,
  115. image_base64: Optional[str] = None,
  116. prompt: str = "",
  117. video_duration: int = 5,
  118. motion_strength: float = 0.5,
  119. start_frame: Optional[str] = None,
  120. end_frame: Optional[str] = None,
  121. seed: int = -1
  122. ) -> Dict[str, Any]:
  123. """图生视频 - Seedream Lite 5.0"""
  124. # 处理图片输入
  125. image_data = None
  126. if image_base64:
  127. image_data = image_base64
  128. elif image_url:
  129. try:
  130. img_response = await self.client.get(image_url)
  131. img_response.raise_for_status()
  132. image_data = base64.b64encode(img_response.content).decode()
  133. except Exception as e:
  134. return {
  135. "task_id": f"error_{int(time.time())}",
  136. "status": "failed",
  137. "error": f"Failed to fetch image: {str(e)}"
  138. }
  139. else:
  140. return {
  141. "task_id": f"error_{int(time.time())}",
  142. "status": "failed",
  143. "error": "Either image_url or image_base64 is required"
  144. }
  145. payload = {
  146. "model": "seedream_lite_5.0",
  147. "image": image_data,
  148. "prompt": prompt,
  149. "duration": video_duration,
  150. "motion_strength": motion_strength,
  151. "seed": seed if seed > 0 else None
  152. }
  153. if start_frame:
  154. payload["start_frame"] = start_frame
  155. if end_frame:
  156. payload["end_frame"] = end_frame
  157. try:
  158. response = await self.client.post(
  159. f"{self.base_url}/v1/image2video",
  160. json=payload,
  161. headers=self._get_headers()
  162. )
  163. response.raise_for_status()
  164. result = response.json()
  165. task_id = result.get("task_id") or result.get("id") or f"task_{int(time.time())}"
  166. task_data = {
  167. "task_id": task_id,
  168. "status": "processing",
  169. "progress": 0,
  170. "type": "image2video",
  171. "created_at": datetime.now().isoformat(),
  172. "estimated_time": video_duration * 10
  173. }
  174. self.cache.set(task_id, task_data)
  175. return task_data
  176. except httpx.HTTPStatusError as e:
  177. return {
  178. "task_id": f"error_{int(time.time())}",
  179. "status": "failed",
  180. "error": f"HTTP {e.response.status_code}: {e.response.text}"
  181. }
  182. except Exception as e:
  183. return {
  184. "task_id": f"error_{int(time.time())}",
  185. "status": "failed",
  186. "error": str(e)
  187. }
  188. async def query_status(self, task_id: str) -> Dict[str, Any]:
  189. """查询任务状态"""
  190. # 先检查缓存
  191. cached = self.cache.get(task_id)
  192. if cached and cached.get("status") == "completed":
  193. return cached
  194. try:
  195. response = await self.client.get(
  196. f"{self.base_url}/v1/tasks/{task_id}",
  197. headers=self._get_headers()
  198. )
  199. response.raise_for_status()
  200. result = response.json()
  201. task_data = {
  202. "task_id": task_id,
  203. "status": result.get("status", "processing"),
  204. "progress": result.get("progress", 0),
  205. }
  206. if result.get("status") == "completed":
  207. task_data["result"] = {
  208. "images": result.get("images", []),
  209. "videos": result.get("videos", []),
  210. "metadata": {
  211. "model": result.get("model"),
  212. "seed": result.get("seed"),
  213. "duration": result.get("duration")
  214. }
  215. }
  216. elif result.get("status") == "failed":
  217. task_data["error"] = result.get("error", "Unknown error")
  218. self.cache.set(task_id, task_data)
  219. return task_data
  220. except httpx.HTTPStatusError as e:
  221. if e.response.status_code == 404:
  222. return {
  223. "task_id": task_id,
  224. "status": "failed",
  225. "error": "Task not found"
  226. }
  227. return {
  228. "task_id": task_id,
  229. "status": "failed",
  230. "error": f"HTTP {e.response.status_code}: {e.response.text}"
  231. }
  232. except Exception as e:
  233. return {
  234. "task_id": task_id,
  235. "status": "failed",
  236. "error": str(e)
  237. }
  238. async def close(self):
  239. """关闭客户端"""
  240. await self.client.aclose()
  241. def cleanup_cache(self):
  242. """清理过期缓存"""
  243. self.cache.cleanup()