liblibai_client.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
import base64
import hashlib
import hmac
import os
import time
import uuid
from typing import Any, Dict, List, Optional
from urllib.parse import urlencode

import requests
from dotenv import load_dotenv
  10. load_dotenv()
  11. TEMPLATE_UUID = "e10adc3949ba59abbe56e057f20f883e"
  12. CHECKPOINT_ID = "0ea388c7eb854be3ba3c6f65aac6bfd3"
  13. CANNY_MODEL_ID = "b6806516962f4e1599a93ac4483c3d23"
  14. class LibLibAIClient:
  15. def __init__(self):
  16. self.access_key = os.getenv("LIBLIBAI_ACCESS_KEY")
  17. self.secret_key = os.getenv("LIBLIBAI_SECRET_KEY")
  18. self.domain = os.getenv("LIBLIBAI_DOMAIN", "https://openapi.liblibai.cloud")
  19. if not self.access_key or not self.secret_key:
  20. raise ValueError("Missing LIBLIBAI_ACCESS_KEY or LIBLIBAI_SECRET_KEY")
  21. def generate_auth_url(self, uri: str) -> str:
  22. ts = str(int(time.time() * 1000))
  23. nonce = uuid.uuid4().hex
  24. sign_str = f"{uri}&{ts}&{nonce}"
  25. dig = hmac.new(self.secret_key.encode(), sign_str.encode(), hashlib.sha1).digest()
  26. signature = base64.urlsafe_b64encode(dig).rstrip(b"=").decode()
  27. return f"{self.domain}{uri}?AccessKey={self.access_key}&Timestamp={ts}&SignatureNonce={nonce}&Signature={signature}"
  28. def get_upload_signature(self, filename: str = "image.png") -> Dict[str, Any]:
  29. uri = "/api/generate/upload/signature"
  30. url = self.generate_auth_url(uri)
  31. extension = filename.split(".")[-1] if "." in filename else "png"
  32. payload = {"name": filename, "extension": extension}
  33. resp = requests.post(url, json=payload)
  34. resp.raise_for_status()
  35. data = resp.json()
  36. if data.get("code") != 0:
  37. raise Exception(f"Get upload signature failed: {data}")
  38. return data["data"]
  39. def upload_image_to_oss(self, image_bytes: bytes, sig_data: Dict[str, Any]) -> str:
  40. data = {"key": sig_data["key"], "policy": sig_data["policy"], "x-oss-date": sig_data["xOssDate"],
  41. "x-oss-expires": sig_data["xOssExpires"], "x-oss-signature": sig_data["xOssSignature"],
  42. "x-oss-credential": sig_data["xOssCredential"], "x-oss-signature-version": sig_data["xOssSignatureVersion"]}
  43. files = {"file": ("image.png", image_bytes, "image/png")}
  44. resp = requests.post(sig_data["postUrl"], data=data, files=files)
  45. if resp.status_code != 204:
  46. raise Exception(f"Upload to OSS failed: {resp.status_code} {resp.text}")
  47. return f"{sig_data["postUrl"]}/{sig_data["key"]}"
  48. def upload_base64_image(self, base64_data: str) -> str:
  49. if base64_data.startswith("data:image"):
  50. base64_data = base64_data.split(",", 1)[1]
  51. image_bytes = base64.b64decode(base64_data)
  52. sig_data = self.get_upload_signature()
  53. return self.upload_image_to_oss(image_bytes, sig_data)
  54. def submit_controlnet_task(self, image_url: str, prompt: str, negative_prompt: str, width: int, height: int,
  55. steps: int, cfg_scale: float, img_count: int, control_weight: float,
  56. preprocessor: int, canny_low: int, canny_high: int) -> str:
  57. uri = "/api/generate/webui/text2img"
  58. url = self.generate_auth_url(uri)
  59. payload = {
  60. "templateUuid": TEMPLATE_UUID,
  61. "generateParams": {
  62. "checkPointId": CHECKPOINT_ID,
  63. "prompt": prompt,
  64. "negativePrompt": negative_prompt,
  65. "sampler": 15,
  66. "steps": steps,
  67. "cfgScale": float(cfg_scale), # 确保是浮点数
  68. "width": width,
  69. "height": height,
  70. "imgCount": img_count,
  71. "seed": -1,
  72. "controlNet": [{
  73. "unitOrder": 1,
  74. "sourceImage": image_url,
  75. "width": width,
  76. "height": height,
  77. "preprocessor": preprocessor,
  78. "annotationParameters": {
  79. "canny": {
  80. "preprocessorResolution": 512,
  81. "lowThreshold": canny_low,
  82. "highThreshold": canny_high
  83. }
  84. },
  85. "model": CANNY_MODEL_ID,
  86. "controlWeight": float(control_weight), # 确保是浮点数
  87. "startingControlStep": 0.0,
  88. "endingControlStep": 1.0,
  89. "pixelPerfect": 1,
  90. "controlMode": 0,
  91. "resizeMode": 1
  92. }]
  93. }
  94. }
  95. resp = requests.post(url, json=payload, timeout=10)
  96. data = resp.json()
  97. if data.get("code") != 0:
  98. raise Exception(f"Submit task failed: {data.get('msg', data)}")
  99. return data["data"]["generateUuid"]
  100. def query_task_status(self, task_id: str) -> Dict[str, Any]:
  101. uri = "/api/generate/webui/status"
  102. url = self.generate_auth_url(uri)
  103. payload = {"generateUuid": task_id}
  104. resp = requests.post(url, json=payload)
  105. resp.raise_for_status()
  106. data = resp.json()
  107. if data.get("code") != 0:
  108. raise Exception(f"Query task failed: {data}")
  109. return data["data"]
  110. def generate_image(self, image: str, prompt: str, negative_prompt: str = "lowres, bad anatomy, text, error",
  111. width: int = 512, height: int = 512, steps: int = 20, cfg_scale: float = 7,
  112. img_count: int = 1, control_weight: float = 1.0, preprocessor: int = 1,
  113. canny_low: int = 100, canny_high: int = 200) -> Dict[str, Any]:
  114. if image.startswith("http://") or image.startswith("https://"):
  115. image_url = image
  116. else:
  117. image_url = self.upload_base64_image(image)
  118. task_id = self.submit_controlnet_task(image_url, prompt, negative_prompt, width, height, steps,
  119. cfg_scale, img_count, control_weight, preprocessor, canny_low, canny_high)
  120. timeout = 300
  121. start_time = time.time()
  122. while time.time() - start_time < timeout:
  123. task_data = self.query_task_status(task_id)
  124. status = task_data.get("generateStatus")
  125. if status == 5:
  126. images = [img["imageUrl"] for img in task_data.get("images", [])]
  127. return {"images": images, "task_id": task_id, "status": "success"}
  128. elif status in [6, 7]:
  129. return {"images": [], "task_id": task_id, "status": "failed"}
  130. time.sleep(5)
  131. return {"images": [], "task_id": task_id, "status": "timeout"}