import os
import hmac
import hashlib
import base64
import uuid
import time
import requests
from typing import Optional, Dict, Any, List
from dotenv import load_dotenv

load_dotenv()

# LibLibAI resource identifiers used by the ControlNet (Canny) text2img workflow.
TEMPLATE_UUID = "e10adc3949ba59abbe56e057f20f883e"
CHECKPOINT_ID = "0ea388c7eb854be3ba3c6f65aac6bfd3"
CANNY_MODEL_ID = "b6806516962f4e1599a93ac4483c3d23"


class LibLibAIError(Exception):
    """Raised when the LibLibAI API or the OSS upload reports a failure."""


class LibLibAIClient:
    """Minimal client for the LibLibAI open API: image upload plus ControlNet generation.

    Credentials come from the environment variables ``LIBLIBAI_ACCESS_KEY`` and
    ``LIBLIBAI_SECRET_KEY``; ``LIBLIBAI_DOMAIN`` optionally overrides the API base URL.
    """

    # Timeout (seconds) for signed API calls.
    request_timeout: float = 10.0
    # Timeout (seconds) for the multipart upload to OSS, which carries the image bytes.
    upload_timeout: float = 60.0

    def __init__(self):
        self.access_key = os.getenv("LIBLIBAI_ACCESS_KEY")
        self.secret_key = os.getenv("LIBLIBAI_SECRET_KEY")
        self.domain = os.getenv("LIBLIBAI_DOMAIN", "https://openapi.liblibai.cloud")
        if not self.access_key or not self.secret_key:
            raise ValueError("Missing LIBLIBAI_ACCESS_KEY or LIBLIBAI_SECRET_KEY")

    def generate_auth_url(self, uri: str) -> str:
        """Return the full URL for *uri* with signed authentication query parameters.

        The signature is HMAC-SHA1 over ``"{uri}&{timestamp_ms}&{nonce}"`` keyed by the
        secret key, encoded as URL-safe base64 with the ``=`` padding stripped.
        """
        ts = str(int(time.time() * 1000))  # wall-clock millis, sent to the server
        nonce = uuid.uuid4().hex
        sign_str = f"{uri}&{ts}&{nonce}"
        dig = hmac.new(self.secret_key.encode(), sign_str.encode(), hashlib.sha1).digest()
        signature = base64.urlsafe_b64encode(dig).rstrip(b"=").decode()
        return (
            f"{self.domain}{uri}?AccessKey={self.access_key}&Timestamp={ts}"
            f"&SignatureNonce={nonce}&Signature={signature}"
        )

    def _post_api(self, uri: str, payload: Dict[str, Any], action: str) -> Any:
        """POST *payload* as JSON to the signed endpoint *uri* and return the body's ``data``.

        Raises:
            requests.HTTPError: on a non-2xx HTTP status.
            LibLibAIError: when the JSON body's ``code`` is not 0.
        """
        resp = requests.post(
            self.generate_auth_url(uri), json=payload, timeout=self.request_timeout
        )
        # Check the HTTP status first: an error page is usually not JSON.
        resp.raise_for_status()
        body = resp.json()
        if body.get("code") != 0:
            raise LibLibAIError(f"{action} failed: {body}")
        return body["data"]

    def get_upload_signature(self, filename: str = "image.png") -> Dict[str, Any]:
        """Request OSS POST-policy credentials for uploading *filename*.

        The extension sent is the text after the last ``.`` in *filename*, or ``"png"``
        when *filename* has no dot.
        """
        extension = filename.split(".")[-1] if "." in filename else "png"
        payload = {"name": filename, "extension": extension}
        return self._post_api("/api/generate/upload/signature", payload, "Get upload signature")

    def upload_image_to_oss(self, image_bytes: bytes, sig_data: Dict[str, Any]) -> str:
        """Upload *image_bytes* to OSS using the policy in *sig_data*; return the object URL.

        Raises:
            LibLibAIError: if OSS does not answer with HTTP 204.
        """
        form = {
            "key": sig_data["key"],
            "policy": sig_data["policy"],
            "x-oss-date": sig_data["xOssDate"],
            "x-oss-expires": sig_data["xOssExpires"],
            "x-oss-signature": sig_data["xOssSignature"],
            "x-oss-credential": sig_data["xOssCredential"],
            "x-oss-signature-version": sig_data["xOssSignatureVersion"],
        }
        files = {"file": ("image.png", image_bytes, "image/png")}
        resp = requests.post(
            sig_data["postUrl"], data=form, files=files, timeout=self.upload_timeout
        )
        if resp.status_code != 204:
            raise LibLibAIError(f"Upload to OSS failed: {resp.status_code} {resp.text}")
        # Inner quotes must differ from the outer ones: reusing them inside an
        # f-string is a SyntaxError before Python 3.12 (PEP 701).
        return f"{sig_data['postUrl']}/{sig_data['key']}"

    def upload_base64_image(self, base64_data: str) -> str:
        """Decode a base64 image (raw or ``data:image/...`` URI), upload it, return its URL.

        NOTE(review): the upload is always declared as ``image.png`` / ``image/png``,
        whatever the actual image format is.
        """
        if base64_data.startswith("data:image"):
            _, sep, base64_data = base64_data.partition(",")
            if not sep:
                raise ValueError("Malformed data URI: missing ',' before the base64 payload")
        image_bytes = base64.b64decode(base64_data)
        sig_data = self.get_upload_signature()
        return self.upload_image_to_oss(image_bytes, sig_data)

    def submit_controlnet_task(self, image_url: str, prompt: str, negative_prompt: str, width: int, height: int, steps: int, cfg_scale: float, img_count: int, control_weight: float, preprocessor: int, canny_low: int, canny_high: int) -> str:
        """Submit a text2img task with one Canny ControlNet unit; return its ``generateUuid``.

        Raises:
            requests.HTTPError: on a non-2xx HTTP status.
            LibLibAIError: when the API reports a non-zero ``code``.
        """
        payload = {
            "templateUuid": TEMPLATE_UUID,
            "generateParams": {
                "checkPointId": CHECKPOINT_ID,
                "prompt": prompt,
                "negativePrompt": negative_prompt,
                "sampler": 15,  # sampler id as defined by the LibLib API — TODO confirm which
                "steps": steps,
                "cfgScale": float(cfg_scale),  # ensure it is sent as a float
                "width": width,
                "height": height,
                "imgCount": img_count,
                "seed": -1,
                "controlNet": [{
                    "unitOrder": 1,
                    "sourceImage": image_url,
                    "width": width,
                    "height": height,
                    "preprocessor": preprocessor,
                    "annotationParameters": {
                        "canny": {
                            "preprocessorResolution": 512,
                            "lowThreshold": canny_low,
                            "highThreshold": canny_high,
                        }
                    },
                    "model": CANNY_MODEL_ID,
                    "controlWeight": float(control_weight),  # ensure it is sent as a float
                    "startingControlStep": 0.0,
                    "endingControlStep": 1.0,
                    "pixelPerfect": 1,
                    "controlMode": 0,
                    "resizeMode": 1,
                }],
            },
        }
        data = self._post_api("/api/generate/webui/text2img", payload, "Submit task")
        return data["generateUuid"]

    def query_task_status(self, task_id: str) -> Dict[str, Any]:
        """Return the status record (the response ``data``) for the task *task_id*."""
        return self._post_api(
            "/api/generate/webui/status", {"generateUuid": task_id}, "Query task"
        )

    def generate_image(self, image: str, prompt: str, negative_prompt: str = "lowres, bad anatomy, text, error", width: int = 512, height: int = 512, steps: int = 20, cfg_scale: float = 7, img_count: int = 1, control_weight: float = 1.0, preprocessor: int = 1, canny_low: int = 100, canny_high: int = 200, timeout: float = 300, poll_interval: float = 5) -> Dict[str, Any]:
        """Run a ControlNet generation end to end and poll until it finishes.

        Args:
            image: an ``http(s)://`` URL used as-is, or base64 / data-URI image data
                that is uploaded first.
            timeout: maximum seconds to keep polling after submission.
            poll_interval: seconds to sleep between status queries.

        Returns:
            ``{"images": [...urls], "task_id": ..., "status": ...}``. ``status`` is
            ``"success"``, ``"failed"`` or ``"timeout"``.
        """
        if image.startswith(("http://", "https://")):
            image_url = image
        else:
            image_url = self.upload_base64_image(image)
        task_id = self.submit_controlnet_task(
            image_url, prompt, negative_prompt, width, height, steps, cfg_scale,
            img_count, control_weight, preprocessor, canny_low, canny_high,
        )
        # Monotonic clock: polling must not be affected by wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            task_data = self.query_task_status(task_id)
            status = task_data.get("generateStatus")
            # Status codes as used here: 5 = finished, 6/7 = failed
            # (exact meaning of 6 vs 7 not shown — TODO confirm against API docs).
            if status == 5:
                # `or []` also covers an explicit "images": null in the response.
                images = [img["imageUrl"] for img in task_data.get("images") or []]
                return {"images": images, "task_id": task_id, "status": "success"}
            if status in (6, 7):
                return {"images": [], "task_id": task_id, "status": "failed"}
            time.sleep(poll_interval)
        return {"images": [], "task_id": task_id, "status": "timeout"}