| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- """即梦AI客户端 - 支持文生图和图生视频"""
import asyncio
import base64
import time
import uuid
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from urllib.parse import quote

import httpx
from pydantic import BaseModel, Field
class TaskCache:
    """In-memory cache of task records whose entries expire after a fixed TTL.

    Expired entries are removed lazily by :meth:`get`, or in bulk by
    :meth:`cleanup`. Nothing removes them automatically otherwise.
    """

    def __init__(self, ttl_hours: float = 24):
        """Create an empty cache.

        Args:
            ttl_hours: Lifetime of each entry in hours (fractions allowed).
        """
        # task_id -> {"data": <record>, "timestamp": <wall-clock datetime>,
        #             "monotonic": <time.monotonic() at insertion>}
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.ttl = timedelta(hours=ttl_hours)

    def _is_expired(self, entry: Dict[str, Any], now: float) -> bool:
        """Return True if *entry* is older than the TTL at monotonic time *now*."""
        # Age is measured on the monotonic clock. Naive local wall-clock time
        # jumps at DST transitions and on clock corrections, which would expire
        # entries early or keep them far past the TTL.
        return now - entry["monotonic"] > self.ttl.total_seconds()

    def set(self, task_id: str, data: Dict[str, Any]) -> None:
        """Store (or replace) the record for *task_id* and restart its TTL."""
        self.cache[task_id] = {
            "data": data,
            # Wall-clock insertion time, kept for callers that inspect entries.
            "timestamp": datetime.now(),
            "monotonic": time.monotonic(),
        }

    def get(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return the cached record for *task_id*, or None if absent or expired.

        An expired entry is deleted as a side effect.
        """
        entry = self.cache.get(task_id)
        if entry is None:
            return None

        if self._is_expired(entry, time.monotonic()):
            del self.cache[task_id]
            return None

        return entry["data"]

    def cleanup(self) -> None:
        """Remove every expired entry."""
        now = time.monotonic()
        # Collect first: deleting while iterating the dict would raise.
        expired = [
            task_id for task_id, entry in self.cache.items()
            if self._is_expired(entry, now)
        ]
        for task_id in expired:
            del self.cache[task_id]
class JimengClient:
    """Async client for the Jimeng AI API: text-to-image and image-to-video.

    The public request methods do not raise on request failures. They return
    a ``dict`` with ``"status": "failed"`` and an ``"error"`` message instead.
    Submitted and polled task records are stored in an in-memory ``TaskCache``.

    The client owns an ``httpx.AsyncClient``. Call :meth:`close` when finished,
    or use ``async with JimengClient(...) as client:``.
    """

    # Fields recorded at submission time that query_status() does not read
    # from the status response; they are carried over from the cached record
    # so that polling does not drop them.
    _CARRIED_FIELDS = ("type", "created_at", "estimated_time")

    def __init__(
        self,
        api_key: Optional[str] = None,
        cookie: Optional[str] = None,
        base_url: str = "https://api.jimeng.ai",
        timeout: float = 30.0,
    ):
        """Create a client.

        Args:
            api_key: Sent as ``Authorization: Bearer <api_key>`` when set.
            cookie: Sent verbatim as the ``Cookie`` header when set.
            base_url: API root URL; a trailing slash is stripped.
            timeout: Timeout in seconds applied to every HTTP request.
        """
        self.api_key = api_key
        self.cookie = cookie
        self.base_url = base_url.rstrip("/")
        self.cache = TaskCache()
        self.client = httpx.AsyncClient(timeout=timeout)

    async def __aenter__(self) -> "JimengClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.close()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _new_id(prefix: str) -> str:
        """Return a locally generated id such as ``task_1700000000_1a2b3c4d``.

        The epoch-seconds part keeps the original id format. The random suffix
        prevents two ids minted within the same second from colliding and
        overwriting each other in the cache.
        """
        return f"{prefix}_{int(time.time())}_{uuid.uuid4().hex[:8]}"

    @classmethod
    def _error(cls, message: str, task_id: Optional[str] = None) -> Dict[str, Any]:
        """Build a failure record; mints a fresh ``error_...`` id if none given."""
        return {
            "task_id": task_id if task_id is not None else cls._new_id("error"),
            "status": "failed",
            "error": message,
        }

    @staticmethod
    def _normalize_seed(seed: int) -> Optional[int]:
        """Map the "random seed" sentinel (any negative value) to None.

        0 is a valid seed and is forwarded unchanged.
        """
        return seed if seed >= 0 else None

    def _get_headers(self) -> Dict[str, str]:
        """Build request headers, adding auth/cookie headers when configured."""
        headers = {
            "Content-Type": "application/json",
            "User-Agent": "JimengAI-Client/1.0",
        }

        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"

        if self.cookie:
            headers["Cookie"] = self.cookie

        return headers

    async def _submit(
        self,
        endpoint: str,
        payload: Dict[str, Any],
        task_type: str,
        estimated_time: int,
    ) -> Dict[str, Any]:
        """POST a generation job to *endpoint* and cache the new task record.

        Returns:
            On success, a ``"processing"`` record with ``task_id``, ``progress``,
            ``type``, ``created_at`` and ``estimated_time``. On any error,
            a failure record (see :meth:`_error`).
        """
        try:
            response = await self.client.post(
                f"{self.base_url}{endpoint}",
                json=payload,
                headers=self._get_headers(),
            )
            response.raise_for_status()
            result = response.json()

            server_id = result.get("task_id") or result.get("id")
            # Normalise to str so the cache key matches the string ids that
            # callers later pass to query_status(). If the response carries no
            # id, a local one is minted; the server will not recognise it.
            task_id = str(server_id) if server_id else self._new_id("task")

            task_data = {
                "task_id": task_id,
                "status": "processing",
                "progress": 0,
                "type": task_type,
                "created_at": datetime.now().isoformat(),
                "estimated_time": estimated_time,
            }

            self.cache.set(task_id, task_data)
            return task_data

        except httpx.HTTPStatusError as e:
            return self._error(f"HTTP {e.response.status_code}: {e.response.text}")
        except Exception as e:
            # Boundary handler: network errors, non-JSON bodies, unexpected
            # response shapes are reported to the caller as a failed task.
            return self._error(str(e))

    # --------------------------------------------------------------- public API

    async def text2image(
        self,
        prompt: str,
        negative_prompt: str = "",
        aspect_ratio: str = "1:1",
        image_count: int = 1,
        cfg_scale: float = 7.0,
        steps: int = 20,
        seed: int = -1
    ) -> Dict[str, Any]:
        """Submit a text-to-image job (model id ``seendance_2.0``).

        Args:
            prompt: Text description of the image.
            negative_prompt: Content to steer away from.
            aspect_ratio: Aspect ratio string such as ``"1:1"``.
            image_count: Number of images, sent as ``num_images``.
            cfg_scale: Guidance scale.
            steps: Sampling steps.
            seed: Generation seed; any negative value requests a random seed.

        Returns:
            The submitted task record, or a failure record on error.
        """
        # NOTE(review): model identifiers are sent verbatim; confirm that
        # "seendance_2.0" is the id the API expects for text-to-image.
        payload = {
            "model": "seendance_2.0",
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "aspect_ratio": aspect_ratio,
            "num_images": image_count,
            "cfg_scale": cfg_scale,
            "steps": steps,
            "seed": self._normalize_seed(seed),
        }
        # Rough client-side heuristic (presumably seconds — confirm).
        estimated_time = steps * image_count * 2
        return await self._submit("/v1/text2image", payload, "text2image", estimated_time)

    async def image2video(
        self,
        image_url: Optional[str] = None,
        image_base64: Optional[str] = None,
        prompt: str = "",
        video_duration: int = 5,
        motion_strength: float = 0.5,
        start_frame: Optional[str] = None,
        end_frame: Optional[str] = None,
        seed: int = -1
    ) -> Dict[str, Any]:
        """Submit an image-to-video job (model id ``seedream_lite_5.0``).

        The source image is taken from *image_base64* if given. Otherwise it is
        downloaded from *image_url* and base64-encoded.

        Args:
            image_url: URL of the source image.
            image_base64: Base64-encoded source image; takes precedence.
            prompt: Optional text guidance.
            video_duration: Requested duration, sent as ``duration``.
            motion_strength: Motion strength value.
            start_frame: Optional value sent as ``start_frame`` when truthy.
            end_frame: Optional value sent as ``end_frame`` when truthy.
            seed: Generation seed; any negative value requests a random seed.

        Returns:
            The submitted task record, or a failure record if no image was
            supplied, the download failed, or the request failed.
        """
        if image_base64:
            image_data = image_base64
        elif image_url:
            try:
                # Fetched without the API auth headers, so credentials are not
                # sent to the image host.
                img_response = await self.client.get(image_url)
                img_response.raise_for_status()
                image_data = base64.b64encode(img_response.content).decode()
            except Exception as e:
                return self._error(f"Failed to fetch image: {str(e)}")
        else:
            return self._error("Either image_url or image_base64 is required")

        # NOTE(review): model identifiers are sent verbatim; confirm that
        # "seedream_lite_5.0" is the id the API expects for image-to-video.
        payload = {
            "model": "seedream_lite_5.0",
            "image": image_data,
            "prompt": prompt,
            "duration": video_duration,
            "motion_strength": motion_strength,
            "seed": self._normalize_seed(seed),
        }

        if start_frame:
            payload["start_frame"] = start_frame
        if end_frame:
            payload["end_frame"] = end_frame

        # Rough client-side heuristic (presumably seconds — confirm).
        estimated_time = video_duration * 10
        return await self._submit("/v1/image2video", payload, "image2video", estimated_time)

    async def query_status(self, task_id: str) -> Dict[str, Any]:
        """Fetch the current state of a task.

        A cached record whose status is ``"completed"`` is returned without a
        network call. Otherwise ``GET /v1/tasks/{task_id}`` is issued and the
        refreshed record is cached.

        Returns:
            A record with ``task_id``, ``status`` and ``progress``. It also has
            ``result`` (images, videos, metadata) when completed, or ``error``
            when failed. The ``type``, ``created_at`` and ``estimated_time``
            fields from the cached record are kept when present. Request
            errors, including HTTP 404 ("Task not found"), yield a failure
            record carrying *task_id*.
        """
        cached = self.cache.get(task_id)
        if cached and cached.get("status") == "completed":
            return cached

        # Percent-encode the id so a value containing "/", "?" or "#" cannot
        # change the request path or add a query string.
        url = f"{self.base_url}/v1/tasks/{quote(str(task_id), safe='')}"

        try:
            response = await self.client.get(url, headers=self._get_headers())
            response.raise_for_status()
            result = response.json()

            status = result.get("status", "processing")
            task_data: Dict[str, Any] = {
                "task_id": task_id,
                "status": status,
                "progress": result.get("progress", 0),
            }
            if cached:
                for key in self._CARRIED_FIELDS:
                    if key in cached:
                        task_data[key] = cached[key]

            if status == "completed":
                task_data["result"] = {
                    "images": result.get("images", []),
                    "videos": result.get("videos", []),
                    "metadata": {
                        "model": result.get("model"),
                        "seed": result.get("seed"),
                        "duration": result.get("duration"),
                    },
                }
            elif status == "failed":
                task_data["error"] = result.get("error", "Unknown error")

            self.cache.set(task_id, task_data)
            return task_data

        except httpx.HTTPStatusError as e:
            if e.response.status_code == 404:
                return self._error("Task not found", task_id)
            return self._error(
                f"HTTP {e.response.status_code}: {e.response.text}", task_id
            )
        except Exception as e:
            # Boundary handler: report any other failure as a failed task.
            return self._error(str(e), task_id)

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self.client.aclose()

    def cleanup_cache(self) -> None:
        """Remove expired entries from the task cache."""
        self.cache.cleanup()
|