liblibai_client.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342
  1. import os
  2. import hmac
  3. import hashlib
  4. import base64
  5. import uuid
  6. import time
  7. import requests
  8. import json
  9. from typing import Optional, Dict, Any, List
  10. from dotenv import load_dotenv
  11. load_dotenv()
  12. TEMPLATE_UUID = "e10adc3949ba59abbe56e057f20f883e"
  13. INSTANT_ID_TEMPLATE_UUID = "7d888009f81d4252a7c458c874cd017f"
  14. CHECKPOINT_ID = os.getenv("LIBLIBAI_DEFAULT_MODEL", "0ea388c7eb854be3ba3c6f65aac6bfd3")
  15. class LibLibAIClient:
  16. def __init__(self):
  17. self.access_key = os.getenv("LIBLIBAI_ACCESS_KEY")
  18. self.secret_key = os.getenv("LIBLIBAI_SECRET_KEY")
  19. self.domain = os.getenv("LIBLIBAI_DOMAIN", "https://openapi.liblibai.cloud")
  20. if not self.access_key or not self.secret_key:
  21. raise ValueError("Missing LIBLIBAI_ACCESS_KEY or LIBLIBAI_SECRET_KEY")
  22. self.models = self._load_models_from_json()
  23. self.sdxl_canny = self._get_model_uuid("线稿类", "Canny(硬边缘)") or "b6806516962f4e1599a93ac4483c3d23"
  24. self.sdxl_softedge = self._get_model_uuid("线稿类", "SoftEdge(软边缘)") or "dda1a0c480bfab9833d9d9a1e4a71fff"
  25. self.sdxl_lineart = self._get_model_uuid("线稿类", "Lineart(线稿)") or "a0f01da42bf48b0ba02c86b6c26b5699"
  26. self.sdxl_openpose = self._get_model_uuid("姿态类", "OpenPose(姿态)") or "2fe4f992a81c5ccbdf8e9851c8c96ff2"
  27. self.sdxl_depth = "6349e9dae8814084bd9c1585d335c24c"
  28. def _load_models_from_json(self):
  29. base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
  30. json_path = os.path.join(base_dir, "data", "liblibai_controlnet_models.json")
  31. if os.path.exists(json_path):
  32. with open(json_path, "r", encoding="utf-8") as f:
  33. return json.load(f)
  34. return {}
  35. def _get_model_uuid(self, category, subtype, xl_only=True):
  36. if category in self.models and subtype in self.models[category]:
  37. for model in self.models[category][subtype]:
  38. if xl_only and model["base_algorithm"] != "基础算法 XL":
  39. continue
  40. return model["uuid"]
  41. return None
  42. def generate_auth_url(self, uri: str) -> str:
  43. ts = str(int(time.time() * 1000))
  44. nonce = uuid.uuid4().hex
  45. sign_str = f"{uri}&{ts}&{nonce}"
  46. dig = hmac.new(self.secret_key.encode(), sign_str.encode(), hashlib.sha1).digest()
  47. signature = base64.urlsafe_b64encode(dig).rstrip(b"=").decode()
  48. return f"{self.domain}{uri}?AccessKey={self.access_key}&Timestamp={ts}&SignatureNonce={nonce}&Signature={signature}"
  49. def get_upload_signature(self, filename: str = "image.png") -> Dict[str, Any]:
  50. uri = "/api/generate/upload/signature"
  51. url = self.generate_auth_url(uri)
  52. extension = filename.split(".")[-1] if "." in filename else "png"
  53. payload = {"name": filename, "extension": extension}
  54. resp = requests.post(url, json=payload)
  55. resp.raise_for_status()
  56. data = resp.json()
  57. if data.get("code") != 0:
  58. raise Exception(f"Get upload signature failed: {data}")
  59. return data["data"]
  60. def upload_image_to_oss(self, image_bytes: bytes, sig_data: Dict[str, Any]) -> str:
  61. data = {"key": sig_data["key"], "policy": sig_data["policy"], "x-oss-date": sig_data["xOssDate"],
  62. "x-oss-expires": sig_data["xOssExpires"], "x-oss-signature": sig_data["xOssSignature"],
  63. "x-oss-credential": sig_data["xOssCredential"], "x-oss-signature-version": sig_data["xOssSignatureVersion"]}
  64. files = {"file": ("image.png", image_bytes, "image/png")}
  65. resp = requests.post(sig_data["postUrl"], data=data, files=files)
  66. if resp.status_code != 204:
  67. raise Exception(f"Upload to OSS failed: {resp.status_code} {resp.text}")
  68. return f"{sig_data['postUrl']}/{sig_data['key']}"
  69. def upload_base64_image(self, base64_data: str) -> str:
  70. if base64_data.startswith("data:image"):
  71. base64_data = base64_data.split(",", 1)[1]
  72. image_bytes = base64.b64decode(base64_data)
  73. sig_data = self.get_upload_signature()
  74. return self.upload_image_to_oss(image_bytes, sig_data)
  75. def process_image_url(self, image: str) -> str:
  76. if image.startswith("http://") or image.startswith("https://"):
  77. # 如果本身就是 liblib 的图床,直接短路返回
  78. if "liblib" in image.lower() or "aliyuncs.com" in image.lower():
  79. return image
  80. # 否则必须先将外部公网 URL 拽下来,转存到 LibLib 的 OSS 中
  81. import httpx
  82. try:
  83. resp = httpx.get(image, timeout=30.0)
  84. resp.raise_for_status()
  85. sig_data = self.get_upload_signature()
  86. return self.upload_image_to_oss(resp.content, sig_data)
  87. except Exception as e:
  88. raise ValueError(f"Failed to fetch and upload external image URL: {e}")
  89. return self.upload_base64_image(image)
  90. def submit_task_payload(self, payload: dict) -> str:
  91. uri = "/api/generate/webui/text2img"
  92. url = self.generate_auth_url(uri)
  93. resp = requests.post(url, json=payload, timeout=15)
  94. data = resp.json()
  95. if data.get("code") != 0:
  96. raise Exception(f"Submit task failed: {data.get('msg', data)} (code: {data.get('code')})")
  97. return data["data"]["generateUuid"]
  98. def query_task_status(self, task_id: str) -> Dict[str, Any]:
  99. uri = "/api/generate/webui/status"
  100. url = self.generate_auth_url(uri)
  101. payload = {"generateUuid": task_id}
  102. resp = requests.post(url, json=payload)
  103. resp.raise_for_status()
  104. data = resp.json()
  105. if data.get("code") != 0:
  106. raise Exception(f"Query task failed: {data}")
  107. return data["data"]
  108. def wait_for_result(self, task_id: str, timeout: int = 300) -> Dict[str, Any]:
  109. start_time = time.time()
  110. while time.time() - start_time < timeout:
  111. task_data = self.query_task_status(task_id)
  112. status = task_data.get("generateStatus")
  113. if status == 5:
  114. images = [img["imageUrl"] for img in task_data.get("images", [])]
  115. return {"images": images, "task_id": task_id, "status": "success"}
  116. elif status in [6, 7]:
  117. return {"images": [], "task_id": task_id, "status": "failed", "detail": f"Status {status}"}
  118. time.sleep(5)
  119. return {"images": [], "task_id": task_id, "status": "timeout"}
  120. def get_model_version_info(self, version_uuid: str) -> dict:
  121. uri = "/api/model/version/get"
  122. auth_url = self.generate_auth_url(uri)
  123. payload = {"versionUuid": version_uuid}
  124. try:
  125. resp = requests.post(auth_url, json=payload, timeout=10)
  126. data = resp.json()
  127. if data.get("code") == 0:
  128. return data.get("data", {})
  129. return {}
  130. except Exception:
  131. return {}
  132. def search_models(self, keyword: str) -> Dict[str, Any]:
  133. url = "http://crawler.aiddit.com/crawler/liblib/keyword"
  134. payload = {"keyword": keyword}
  135. headers = {
  136. "Content-Type": "application/json",
  137. "Cookie": "_xsrf=2|c9d0a1bf|891f10d6ea5abc19d58be0d2fac84e6a|1774447752"
  138. }
  139. resp = requests.post(url, json=payload, headers=headers, timeout=15)
  140. resp.raise_for_status()
  141. return resp.json()
  142. def get_model_detail(self, content_link: str = None, uuid: str = None, version_uuid: str = None) -> Dict[str, Any]:
  143. url = "http://crawler.aiddit.com/crawler/liblib/detail"
  144. if not content_link:
  145. if not uuid or not version_uuid:
  146. raise ValueError("Must provide either content_link or uuid and version_uuid")
  147. content_link = f"https://www.liblib.art/modelinfo/{uuid}?from=search&versionUuid={version_uuid}"
  148. payload = {"content_link": content_link}
  149. headers = {
  150. "Content-Type": "application/json",
  151. "Cookie": "_xsrf=2|c9d0a1bf|891f10d6ea5abc19d58be0d2fac84e6a|1774447752"
  152. }
  153. resp = requests.post(url, json=payload, headers=headers, timeout=15)
  154. resp.raise_for_status()
  155. return resp.json()
  156. def generate_advanced(self, mode: str, prompt: str, image: Optional[str] = None,
  157. mask_image: Optional[str] = None, pose_image: Optional[str] = None,
  158. control_nets: Optional[List[Dict[str, Any]]] = None,
  159. negative_prompt: str = "lowres, bad anatomy, error",
  160. width: int = 512, height: int = 512, steps: int = 20,
  161. cfg_scale: float = 7.0, img_count: int = 1,
  162. base_model_uuid: Optional[str] = None) -> Dict[str, Any]:
  163. # Base shared params
  164. generate_params = {
  165. "checkPointId": base_model_uuid if base_model_uuid else CHECKPOINT_ID,
  166. "prompt": prompt,
  167. "negativePrompt": negative_prompt,
  168. "sampler": 15,
  169. "steps": steps,
  170. "cfgScale": float(cfg_scale),
  171. "width": width,
  172. "height": height,
  173. "imgCount": img_count
  174. }
  175. payload = {
  176. "templateUuid": TEMPLATE_UUID,
  177. "generateParams": generate_params
  178. }
  179. # text2img does not need modifications
  180. if mode == "text2img":
  181. pass
  182. elif mode == "img2img":
  183. if not image:
  184. raise ValueError("Image is required for img2img mode")
  185. img_url = self.process_image_url(image)
  186. generate_params["sourceImage"] = img_url
  187. generate_params["denoisingStrength"] = 0.5
  188. elif mode in ["canny", "softedge", "lineart", "openpose", "depth", "controlnet"]:
  189. cnets_to_process = []
  190. if control_nets and len(control_nets) > 0:
  191. cnets_to_process = control_nets
  192. else:
  193. if not image:
  194. raise ValueError(f"Image is required for {mode} mode")
  195. cnets_to_process = [{"mode": mode, "image": image}]
  196. final_cnet_configs = []
  197. for idx, cnet in enumerate(cnets_to_process):
  198. cnet_mode = cnet.get("mode")
  199. cnet_img = cnet.get("image")
  200. cnet_weight = float(cnet.get("weight", 1.0))
  201. if not cnet_img:
  202. continue
  203. img_url = self.process_image_url(cnet_img)
  204. cnet_config = {
  205. "unitOrder": idx + 1,
  206. "sourceImage": img_url,
  207. "width": width,
  208. "height": height,
  209. "controlWeight": cnet_weight,
  210. "startingControlStep": 0.0,
  211. "endingControlStep": 1.0,
  212. "pixelPerfect": 1,
  213. "controlMode": 0
  214. }
  215. if cnet_mode == "canny":
  216. cnet_config.update({
  217. "preprocessor": 1,
  218. "model": self.sdxl_canny,
  219. "annotationParameters": {
  220. "canny": {"preprocessorResolution": 512, "lowThreshold": 100, "highThreshold": 200}
  221. }
  222. })
  223. elif cnet_mode == "softedge":
  224. cnet_config.update({
  225. "preprocessor": 5,
  226. "model": self.sdxl_softedge,
  227. "annotationParameters": {
  228. "hed": {"preprocessorResolution": 512}
  229. }
  230. })
  231. elif cnet_mode == "lineart":
  232. cnet_config.update({
  233. "preprocessor": 32,
  234. "model": self.sdxl_lineart,
  235. "annotationParameters": {
  236. "lineart": {"preprocessorResolution": 512}
  237. }
  238. })
  239. elif cnet_mode == "openpose":
  240. cnet_config.update({
  241. "preprocessor": 14,
  242. "model": self.sdxl_openpose,
  243. "annotationParameters": {
  244. "openposeFull": {"preprocessorResolution": 512}
  245. }
  246. })
  247. elif cnet_mode == "depth":
  248. cnet_config.update({
  249. # Assuming Midas depth (9) or Zoe depth (39). Usually 9 is safe
  250. "preprocessor": 9,
  251. "model": self.sdxl_depth,
  252. "annotationParameters": {
  253. "depthMidas": {"preprocessorResolution": 512}
  254. }
  255. })
  256. final_cnet_configs.append(cnet_config)
  257. if not final_cnet_configs:
  258. raise ValueError("No valid control_nets processed.")
  259. generate_params["controlNet"] = final_cnet_configs
  260. elif mode == "inpaint":
  261. if not image or not mask_image:
  262. raise ValueError("Both image and mask_image are required for inpaint mode")
  263. img_url = self.process_image_url(image)
  264. mask_url = self.process_image_url(mask_image)
  265. generate_params["mode"] = 4
  266. generate_params["sourceImage"] = img_url
  267. generate_params["denoisingStrength"] = 0.5
  268. generate_params["inpaintParam"] = {
  269. "maskImage": mask_url,
  270. "maskBlur": 4,
  271. "inpaintArea": 0
  272. }
  273. elif mode == "instantid":
  274. if not image or not pose_image:
  275. raise ValueError("Both face image (image) and pose_image are required for instantid")
  276. payload["templateUuid"] = INSTANT_ID_TEMPLATE_UUID
  277. generate_params["sampler"] = 20
  278. face_img_url = self.process_image_url(image)
  279. pose_img_url = self.process_image_url(pose_image)
  280. generate_params["controlNet"] = [
  281. {
  282. "unitOrder": 1,
  283. "sourceImage": face_img_url,
  284. "width": 1080,
  285. "height": 1432
  286. },
  287. {
  288. "unitOrder": 2,
  289. "sourceImage": pose_img_url,
  290. "width": 1024,
  291. "height": 1024
  292. }
  293. ]
  294. else:
  295. raise ValueError(f"Unknown mode: {mode}")
  296. task_id = self.submit_task_payload(payload)
  297. return self.wait_for_result(task_id)
  298. # Legacy method for backwards compatibility
  299. def generate_image(self, *args, **kwargs):
  300. raise NotImplementedError("generate_image is deprecated. Use generate_advanced with mode='canny'")