"""即梦AI客户端 - 支持文生图和图生视频""" import asyncio import base64 import time from typing import Optional, Dict, Any, List from datetime import datetime, timedelta import httpx from pydantic import BaseModel, Field class TaskCache: """任务状态缓存""" def __init__(self, ttl_hours: int = 24): self.cache: Dict[str, Dict[str, Any]] = {} self.ttl = timedelta(hours=ttl_hours) def set(self, task_id: str, data: Dict[str, Any]): self.cache[task_id] = { "data": data, "timestamp": datetime.now() } def get(self, task_id: str) -> Optional[Dict[str, Any]]: if task_id not in self.cache: return None entry = self.cache[task_id] if datetime.now() - entry["timestamp"] > self.ttl: del self.cache[task_id] return None return entry["data"] def cleanup(self): """清理过期缓存""" now = datetime.now() expired = [ task_id for task_id, entry in self.cache.items() if now - entry["timestamp"] > self.ttl ] for task_id in expired: del self.cache[task_id] class JimengClient: """即梦AI客户端""" def __init__( self, api_key: Optional[str] = None, cookie: Optional[str] = None, base_url: str = "https://api.jimeng.ai" ): self.api_key = api_key self.cookie = cookie self.base_url = base_url.rstrip("/") self.cache = TaskCache() self.client = httpx.AsyncClient(timeout=30.0) def _get_headers(self) -> Dict[str, str]: """构建请求头""" headers = { "Content-Type": "application/json", "User-Agent": "JimengAI-Client/1.0" } if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" if self.cookie: headers["Cookie"] = self.cookie return headers async def text2image( self, prompt: str, negative_prompt: str = "", aspect_ratio: str = "1:1", image_count: int = 1, cfg_scale: float = 7.0, steps: int = 20, seed: int = -1 ) -> Dict[str, Any]: """文生图 - Seendance 2.0""" payload = { "model": "seendance_2.0", "prompt": prompt, "negative_prompt": negative_prompt, "aspect_ratio": aspect_ratio, "num_images": image_count, "cfg_scale": cfg_scale, "steps": steps, "seed": seed if seed > 0 else None } try: response = await self.client.post( f"{self.base_url}/v1/text2image", json=payload, headers=self._get_headers() ) response.raise_for_status() result = response.json() task_id = result.get("task_id") or result.get("id") or f"task_{int(time.time())}" task_data = { "task_id": task_id, "status": "processing", "progress": 0, "type": "text2image", "created_at": datetime.now().isoformat(), "estimated_time": steps * image_count * 2 } self.cache.set(task_id, task_data) return task_data except httpx.HTTPStatusError as e: return { "task_id": f"error_{int(time.time())}", "status": "failed", "error": f"HTTP {e.response.status_code}: {e.response.text}" } except Exception as e: return { "task_id": f"error_{int(time.time())}", "status": "failed", "error": str(e) } async def image2video( self, image_url: Optional[str] = None, image_base64: Optional[str] = None, prompt: str = "", video_duration: int = 5, motion_strength: float = 0.5, start_frame: Optional[str] = None, end_frame: Optional[str] = None, seed: int = -1 ) -> Dict[str, Any]: """图生视频 - Seedream Lite 5.0""" # 处理图片输入 image_data = None if image_base64: image_data = image_base64 elif image_url: try: img_response = await self.client.get(image_url) img_response.raise_for_status() image_data = base64.b64encode(img_response.content).decode() except Exception as e: return { "task_id": f"error_{int(time.time())}", "status": "failed", "error": f"Failed to fetch image: {str(e)}" } else: return { "task_id": f"error_{int(time.time())}", "status": "failed", "error": "Either image_url or image_base64 is required" } payload = { "model": "seedream_lite_5.0", "image": image_data, "prompt": prompt, "duration": video_duration, "motion_strength": motion_strength, "seed": seed if seed > 0 else None } if start_frame: payload["start_frame"] = start_frame if end_frame: payload["end_frame"] = end_frame try: response = await self.client.post( f"{self.base_url}/v1/image2video", json=payload, headers=self._get_headers() ) response.raise_for_status() result = response.json() task_id = result.get("task_id") or result.get("id") or f"task_{int(time.time())}" task_data = { "task_id": task_id, "status": "processing", "progress": 0, "type": "image2video", "created_at": datetime.now().isoformat(), "estimated_time": video_duration * 10 } self.cache.set(task_id, task_data) return task_data except httpx.HTTPStatusError as e: return { "task_id": f"error_{int(time.time())}", "status": "failed", "error": f"HTTP {e.response.status_code}: {e.response.text}" } except Exception as e: return { "task_id": f"error_{int(time.time())}", "status": "failed", "error": str(e) } async def query_status(self, task_id: str) -> Dict[str, Any]: """查询任务状态""" # 先检查缓存 cached = self.cache.get(task_id) if cached and cached.get("status") == "completed": return cached try: response = await self.client.get( f"{self.base_url}/v1/tasks/{task_id}", headers=self._get_headers() ) response.raise_for_status() result = response.json() task_data = { "task_id": task_id, "status": result.get("status", "processing"), "progress": result.get("progress", 0), } if result.get("status") == "completed": task_data["result"] = { "images": result.get("images", []), "videos": result.get("videos", []), "metadata": { "model": result.get("model"), "seed": result.get("seed"), "duration": result.get("duration") } } elif result.get("status") == "failed": task_data["error"] = result.get("error", "Unknown error") self.cache.set(task_id, task_data) return task_data except httpx.HTTPStatusError as e: if e.response.status_code == 404: return { "task_id": task_id, "status": "failed", "error": "Task not found" } return { "task_id": task_id, "status": "failed", "error": f"HTTP {e.response.status_code}: {e.response.text}" } except Exception as e: return { "task_id": task_id, "status": "failed", "error": str(e) } async def close(self): """关闭客户端""" await self.client.aclose() def cleanup_cache(self): """清理过期缓存""" self.cache.cleanup()