# test_liblibai_workflows.py (18 KB)

  1. import os
  2. import sys
  3. import time
  4. import requests
  5. import json
  6. from typing import Dict, Any
  7. sys.path.append(os.path.join(os.path.dirname(__file__), '../tools/local/liblibai_controlnet'))
  8. from liblibai_client import LibLibAIClient
  9. # 常量配置 (根据文档)
  10. TEMPLATE_UUID = "e10adc3949ba59abbe56e057f20f883e" # SD1.5 & SDXL 通用自定义参数模板
  11. # 请确保这是一个有效的 SDXL 模型,这样才能匹配底下的 SDXL Canny 模型
  12. # 这里用的是代码里原本的 Checkpoint ID,请根据你们自己的 Liblib 模型库调整!
  13. DEFAULT_CHECKPOINT_ID = "0ea388c7eb854be3ba3c6f65aac6bfd3"
  14. class LibLibTestRunner:
  15. def __init__(self):
  16. self.client = LibLibAIClient()
  17. self.models = self._load_models_from_json()
  18. # 动态匹配基础算法 XL 的各种控制网模型
  19. self.sdxl_canny = self._get_model_uuid("线稿类", "Canny(硬边缘)", xl_only=True) or "b6806516962f4e1599a93ac4483c3d23"
  20. self.sdxl_softedge = self._get_model_uuid("线稿类", "SoftEdge(软边缘)", xl_only=True) or "dda1a0c480bfab9833d9d9a1e4a71fff"
  21. self.sdxl_lineart = self._get_model_uuid("线稿类", "Lineart(线稿)", xl_only=True) or "a0f01da42bf48b0ba02c86b6c26b5699"
  22. self.sdxl_openpose = self._get_model_uuid("姿态类", "OpenPose(姿态)", xl_only=True) or "2fe4f992a81c5ccbdf8e9851c8c96ff2"
  23. self._verify_models()
  24. def _load_models_from_json(self):
  25. json_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/liblibai_controlnet_models.json"))
  26. if os.path.exists(json_path):
  27. with open(json_path, "r", encoding="utf-8") as f:
  28. return json.load(f)
  29. return {}
  30. def _get_model_uuid(self, category, subtype, xl_only=True):
  31. if category in self.models and subtype in self.models[category]:
  32. for model in self.models[category][subtype]:
  33. if xl_only and model["base_algorithm"] != "基础算法 XL":
  34. continue
  35. return model["uuid"]
  36. return None
  37. def _verify_models(self):
  38. print("="*50)
  39. print("校验底模信息 (使用 api/model/version/get)")
  40. print("="*50)
  41. models_to_verify = {
  42. "CheckPoint / 底模": DEFAULT_CHECKPOINT_ID,
  43. }
  44. for name, uuid in models_to_verify.items():
  45. info = self.client.get_model_version_info(uuid)
  46. status = f"{info.get('model_name', 'Unknown')} (Base: {info.get('baseAlgo', 'Unknown')})" if info else "Failed to verify"
  47. print(f"- {name}: [{uuid}] -> {status}")
  48. print("\n")
  49. def _submit_task(self, payload: dict) -> str:
  50. url = self.client.generate_auth_url("/api/generate/webui/text2img")
  51. print(f"Submitting payload...")
  52. resp = requests.post(url, json=payload, timeout=10)
  53. data = resp.json()
  54. if data.get("code") != 0:
  55. raise Exception(f"Submit task failed: {data.get('msg')} (code: {data.get('code')})")
  56. return data["data"]["generateUuid"]
  57. def _wait_and_print_result(self, task_id: str):
  58. print(f"Task submitted successfully! Task ID: {task_id}")
  59. print("Waiting for result...")
  60. timeout = 300
  61. start_time = time.time()
  62. while time.time() - start_time < timeout:
  63. task_data = self.client.query_task_status(task_id)
  64. status = task_data.get("generateStatus")
  65. if status == 5: # Success
  66. images = [img["imageUrl"] for img in task_data.get("images", [])]
  67. print(f"\n[SUCCESSS] Generated images: {images}")
  68. # --- 新增: 自动下载图片 ---
  69. output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "output"))
  70. os.makedirs(output_dir, exist_ok=True)
  71. for i, img_url in enumerate(images):
  72. try:
  73. img_resp = requests.get(img_url)
  74. img_resp.raise_for_status()
  75. timestamp = int(time.time())
  76. filename = f"workflow_result_{timestamp}_{i}.png"
  77. filepath = os.path.join(output_dir, filename)
  78. with open(filepath, "wb") as f:
  79. f.write(img_resp.content)
  80. print(f"[{i}] 已自动保存到本地: {filepath}")
  81. except Exception as e:
  82. print(f"图片自动下载失败: {e}")
  83. return
  84. elif status in [6, 7]: # Failed or Cancelled
  85. print(f"\n[FAILED] Task failed. Audit status: {task_data.get('auditStatus')}")
  86. return
  87. print(".", end="", flush=True)
  88. time.sleep(5)
  89. print("\n[TIMEOUT] Wait for task timed out.")
  90. # 1. 纯文生图测试
  91. def test_text2img(self):
  92. print("\n" + "="*50)
  93. print("TEST 1: 纯文生图 (Text to Image)")
  94. print("="*50)
  95. payload = {
  96. "templateUuid": TEMPLATE_UUID,
  97. "generateParams": {
  98. "checkPointId": DEFAULT_CHECKPOINT_ID,
  99. "prompt": "1girl, cute, beautiful landscape, masterpiece",
  100. "negativePrompt": "lowres, bad anatomy, text, error",
  101. "sampler": 15,
  102. "steps": 20,
  103. "cfgScale": 7.0,
  104. "width": 512,
  105. "height": 512,
  106. "imgCount": 1
  107. }
  108. }
  109. task_id = self._submit_task(payload)
  110. self._wait_and_print_result(task_id)
  111. # 2. 纯图生图测试
  112. def test_img2img(self):
  113. print("\n" + "="*50)
  114. print("TEST 2: 纯图生图 (Image to Image)")
  115. print("="*50)
  116. test_img_url = "https://liblibai-airship-temp.oss-cn-beijing.aliyuncs.com/aliyun-cn-prod/73ed6ae42b144d21bf566e05b5a6c138.png"
  117. payload = {
  118. "templateUuid": TEMPLATE_UUID,
  119. "generateParams": {
  120. "checkPointId": DEFAULT_CHECKPOINT_ID,
  121. "mode": 0, # 0 代表图生图
  122. "sourceImage": test_img_url,
  123. "denoisingStrength": 0.75, # 去噪强度
  124. "prompt": "white line art, cat",
  125. "negativePrompt": "lowres, bad anatomy, error",
  126. "sampler": 15,
  127. "steps": 20,
  128. "cfgScale": 7.0,
  129. "width": 512,
  130. "height": 512,
  131. "imgCount": 1
  132. }
  133. }
  134. task_id = self._submit_task(payload)
  135. self._wait_and_print_result(task_id)
  136. # 3. 文生图 + ControlNet (Canny) 测试
  137. def test_text2img_controlnet(self):
  138. print("\n" + "="*50)
  139. print("TEST 3: 文生图 + ControlNet (Text2Img with ControlNet Canny)")
  140. print("="*50)
  141. print("注意: 如果默认 Checkpoint 是 SD1.5, 下面配置的 SDXL Canny 模型可能会导致访问拒绝!")
  142. test_img_url = "https://liblibai-airship-temp.oss-cn-beijing.aliyuncs.com/aliyun-cn-prod/73ed6ae42b144d21bf566e05b5a6c138.png"
  143. payload = {
  144. "templateUuid": TEMPLATE_UUID,
  145. "generateParams": {
  146. "checkPointId": DEFAULT_CHECKPOINT_ID,
  147. "prompt": "simple white line art, cat, black background",
  148. "negativePrompt": "lowres, bad anatomy",
  149. "sampler": 15,
  150. "steps": 20,
  151. "cfgScale": 7.0,
  152. "width": 512,
  153. "height": 512,
  154. "imgCount": 1,
  155. "controlNet": [{
  156. "unitOrder": 1,
  157. "sourceImage": test_img_url,
  158. "width": 512,
  159. "height": 512,
  160. "preprocessor": 1, # 1 = Canny
  161. "model": self.sdxl_canny,
  162. "controlWeight": 1.0,
  163. "startingControlStep": 0.0,
  164. "endingControlStep": 1.0,
  165. "pixelPerfect": 1,
  166. "controlMode": 0,
  167. "annotationParameters": {
  168. "canny": {
  169. "preprocessorResolution": 512,
  170. "lowThreshold": 100,
  171. "highThreshold": 200
  172. }
  173. }
  174. }]
  175. }
  176. }
  177. task_id = self._submit_task(payload)
  178. self._wait_and_print_result(task_id)
  179. # 4. 文生图 + 边缘 (SoftEdge/HED)
  180. def test_text2img_softedge(self):
  181. print("\n" + "="*50)
  182. print("TEST 4: 文生图 + SoftEdge (软边缘 控制网)")
  183. print("="*50)
  184. # 测试用图片
  185. test_img_url = "https://liblibai-airship-temp.oss-cn-beijing.aliyuncs.com/aliyun-cn-prod/73ed6ae42b144d21bf566e05b5a6c138.png"
  186. payload = {
  187. "templateUuid": TEMPLATE_UUID,
  188. "generateParams": {
  189. "checkPointId": DEFAULT_CHECKPOINT_ID,
  190. "prompt": "Soft edge artwork, beautiful soft lighting, cat",
  191. "negativePrompt": "lowres, bad anatomy",
  192. "sampler": 15,
  193. "steps": 20,
  194. "cfgScale": 7.0,
  195. "width": 512,
  196. "height": 512,
  197. "imgCount": 1,
  198. "controlNet": [{
  199. "unitOrder": 1,
  200. "sourceImage": test_img_url,
  201. "width": 512,
  202. "height": 512,
  203. "preprocessor": 5, # 5 = HED / 软边缘
  204. "model": self.sdxl_softedge,
  205. "controlWeight": 1.0,
  206. "startingControlStep": 0.0,
  207. "endingControlStep": 1.0,
  208. "pixelPerfect": 1,
  209. "controlMode": 0,
  210. "annotationParameters": {
  211. "hed": {
  212. "preprocessorResolution": 512
  213. }
  214. }
  215. }]
  216. }
  217. }
  218. task_id = self._submit_task(payload)
  219. self._wait_and_print_result(task_id)
  220. # 5. 文生图 + 线稿 (Lineart)
  221. def test_text2img_lineart(self):
  222. print("\n" + "="*50)
  223. print("TEST 5: 文生图 + Lineart (线稿 控制网)")
  224. print("="*50)
  225. test_img_url = "https://liblibai-airship-temp.oss-cn-beijing.aliyuncs.com/aliyun-cn-prod/73ed6ae42b144d21bf566e05b5a6c138.png"
  226. payload = {
  227. "templateUuid": TEMPLATE_UUID,
  228. "generateParams": {
  229. "checkPointId": DEFAULT_CHECKPOINT_ID,
  230. "prompt": "Detailed coloring, colorful fantasy style, masterpiece, cat",
  231. "negativePrompt": "lowres",
  232. "sampler": 15,
  233. "steps": 20,
  234. "cfgScale": 7.0,
  235. "width": 512,
  236. "height": 512,
  237. "imgCount": 1,
  238. "controlNet": [{
  239. "unitOrder": 1,
  240. "sourceImage": test_img_url,
  241. "width": 512,
  242. "height": 512,
  243. "preprocessor": 32, # 32 = Lineart Standard
  244. "model": self.sdxl_lineart,
  245. "controlWeight": 1.0,
  246. "startingControlStep": 0.0,
  247. "endingControlStep": 1.0,
  248. "pixelPerfect": 1,
  249. "controlMode": 0,
  250. "annotationParameters": {
  251. "lineart": {
  252. "preprocessorResolution": 512
  253. }
  254. }
  255. }]
  256. }
  257. }
  258. task_id = self._submit_task(payload)
  259. self._wait_and_print_result(task_id)
  260. # 6. 文生图 + 骨骼 (OpenPose)
  261. def test_text2img_openpose(self):
  262. print("\n" + "="*50)
  263. print("TEST 6: 文生图 + OpenPose (骨骼 控制网)")
  264. print("="*50)
  265. # 测试用图片(最好使用含有人物动作的图)
  266. test_img_url = "https://liblibai-airship-temp.oss-cn-beijing.aliyuncs.com/aliyun-cn-prod/73ed6ae42b144d21bf566e05b5a6c138.png"
  267. payload = {
  268. "templateUuid": TEMPLATE_UUID,
  269. "generateParams": {
  270. "checkPointId": DEFAULT_CHECKPOINT_ID,
  271. "prompt": "1girl, dancing pose, beautiful dress",
  272. "negativePrompt": "lowres",
  273. "sampler": 15,
  274. "steps": 20,
  275. "cfgScale": 7.0,
  276. "width": 512,
  277. "height": 512,
  278. "imgCount": 1,
  279. "controlNet": [{
  280. "unitOrder": 1,
  281. "sourceImage": test_img_url,
  282. "width": 512,
  283. "height": 512,
  284. "preprocessor": 14, # 14 = OpenPose Full
  285. "model": self.sdxl_openpose,
  286. "controlWeight": 1.0,
  287. "startingControlStep": 0.0,
  288. "endingControlStep": 1.0,
  289. "pixelPerfect": 1,
  290. "controlMode": 0,
  291. "annotationParameters": {
  292. "openposeFull": {
  293. "preprocessorResolution": 512
  294. }
  295. }
  296. }]
  297. }
  298. }
  299. task_id = self._submit_task(payload)
  300. self._wait_and_print_result(task_id)
  301. # 7. 局部重绘 (Inpaint Mode 4)
  302. def test_inpaint_mode4(self):
  303. print("\n" + "="*50)
  304. print("TEST 7: 局部重绘 (Inpaint Mode 4 蒙版重绘)")
  305. print("="*50)
  306. test_img_url = "https://liblibai-airship-temp.oss-cn-beijing.aliyuncs.com/aliyun-cn-prod/73ed6ae42b144d21bf566e05b5a6c138.png"
  307. test_mask_url = test_img_url # 仅作演示,实际应用中应当是一个黑底白色的蒙版图片
  308. payload = {
  309. # Inpainting 可能需要特殊模板,如不兼容请参考文档
  310. "templateUuid": TEMPLATE_UUID,
  311. "generateParams": {
  312. "checkPointId": DEFAULT_CHECKPOINT_ID,
  313. "mode": 4, # 4 = Inpaint 蒙版重绘
  314. "sourceImage": test_img_url,
  315. "denoisingStrength": 0.5,
  316. "prompt": "A completely different background, sci-fi city",
  317. "negativePrompt": "lowres",
  318. "sampler": 15,
  319. "steps": 20,
  320. "cfgScale": 7.0,
  321. "width": 512,
  322. "height": 512,
  323. "imgCount": 1,
  324. "inpaintParam": {
  325. "maskImage": test_mask_url,
  326. "maskBlur": 4,
  327. "inpaintArea": 0
  328. }
  329. }
  330. }
  331. task_id = self._submit_task(payload)
  332. self._wait_and_print_result(task_id)
  333. # 8. 人像换脸 (InstantID)
  334. def test_instantid_faceswap(self):
  335. print("\n" + "="*50)
  336. print("TEST 8: 人像换脸 (InstantID)")
  337. print("="*50)
  338. # Note: InstantID faceswap uses a UNIQUE templateUuid
  339. INSTANT_ID_TEMPLATE_UUID = "7d888009f81d4252a7c458c874cd017f"
  340. face_img_url = "https://liblibai-online.liblib.cloud/img/081e9f07d9bd4c2ba090efde163518f9/49943c0b-4d79-4e2f-8c55-bc1e5b8c69d8.png"
  341. pose_img_url = "https://liblibai-online.liblib.cloud/img/081e9f07d9bd4c2ba090efde163518f9/e713676d-baaa-4dac-99b9-d5d814a29f9f.png"
  342. payload = {
  343. "templateUuid": INSTANT_ID_TEMPLATE_UUID,
  344. "generateParams": {
  345. "checkPointId": DEFAULT_CHECKPOINT_ID, # 仅 XL模型支持人像换脸
  346. "prompt": "Asian portrait,A young woman wearing a green baseball cap, close shot, background is coffee store, masterpiece, best quality, ultra resolution",
  347. "width": 768,
  348. "height": 1152,
  349. "sampler": 20,
  350. "steps": 35,
  351. "cfgScale": 2.0,
  352. "imgCount": 1,
  353. "controlNet": [
  354. {
  355. "unitOrder": 1, # 第一步:先识别要用的人像人脸
  356. "sourceImage": face_img_url,
  357. "width": 1080,
  358. "height": 1432
  359. },
  360. {
  361. "unitOrder": 2, # 第二步:再识别要参考的人物面部朝向
  362. "sourceImage": pose_img_url,
  363. "width": 1024,
  364. "height": 1024
  365. }
  366. ]
  367. }
  368. }
  369. task_id = self._submit_task(payload)
  370. self._wait_and_print_result(task_id)
  371. if __name__ == "__main__":
  372. try:
  373. runner = LibLibTestRunner()
  374. print("选择要测试的模式:")
  375. print("1. 纯文生图 (Text2Img)")
  376. print("2. 纯图生图 (Img2Img)")
  377. print("3. 硬边缘控制 (Canny ControlNet)")
  378. print("4. 软边缘控制 (SoftEdge ControlNet)")
  379. print("5. 线稿控制 (Lineart ControlNet)")
  380. print("6. 骨骼控制 (OpenPose ControlNet)")
  381. print("7. 局部重绘 (Inpaint Mode=4)")
  382. print("8. 人像换脸 (InstantID)")
  383. print("9. 全部测试 (All)")
  384. if len(sys.argv) > 1:
  385. choice = sys.argv[1]
  386. else:
  387. # 默认测试文生图以验证 API Key 最基本权限
  388. choice = "1"
  389. if choice == "1":
  390. runner.test_text2img()
  391. elif choice == "2":
  392. runner.test_img2img()
  393. elif choice == "3":
  394. runner.test_text2img_controlnet()
  395. elif choice == "4":
  396. runner.test_text2img_softedge()
  397. elif choice == "5":
  398. runner.test_text2img_lineart()
  399. elif choice == "6":
  400. runner.test_text2img_openpose()
  401. elif choice == "7":
  402. runner.test_inpaint_mode4()
  403. elif choice == "8":
  404. runner.test_instantid_faceswap()
  405. elif choice == "9":
  406. runner.test_text2img()
  407. runner.test_img2img()
  408. runner.test_text2img_controlnet()
  409. # 省略后面的测试避免同时发生太多请求...
  410. else:
  411. print("Unknown choice.")
  412. except Exception as e:
  413. print(f"\n[FATAL ERROR] {e}")