test_liblib_depth_pipeline.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import sys
  2. import os
  3. import json
  4. import base64
  5. sys.stdout.reconfigure(encoding='utf-8')
  6. script_dir = os.path.dirname(os.path.abspath(__file__))
  7. tools_dir = os.path.join(script_dir, '..', 'tools', 'local', 'liblibai_controlnet')
  8. sys.path.append(tools_dir)
  9. try:
  10. from liblibai_client import LibLibAIClient, TEMPLATE_UUID, INSTANT_ID_TEMPLATE_UUID
  11. except ImportError:
  12. print(f"Failed to import LibLibAIClient. Make sure the path {tools_dir} is correct.")
  13. sys.exit(1)
  14. def main():
  15. client = LibLibAIClient()
  16. # 1. Search for a checkpoint's UUID
  17. keyword = "Juggernaut XL"
  18. print(f"1. Searching for Checkpoint matching '{keyword}'...")
  19. search_res = client.search_models(keyword)
  20. if not search_res or search_res.get('code') != 0 or not search_res.get('data', {}).get('data'):
  21. print("Search failed or no models found.")
  22. print(search_res)
  23. return
  24. first_model = search_res['data']['data'][0]
  25. uuid = first_model['uuid']
  26. version_uuid = first_model['versionUuid']
  27. model_name = first_model.get('name')
  28. base_types = first_model.get('baseType', [])
  29. print(f" -> Found Checkpoint: '{model_name}'")
  30. print(f" -> UUID: {uuid}")
  31. print(f" -> Version UUID: {version_uuid}")
  32. print(f" -> baseType: {base_types} (1=1.5, 2=XL etc.)\n")
  33. # 2. Match local template_id
  34. print(f"2. Matching local template_id based on baseType ...")
  35. print(f" -> Client is currently configured to use TEMPLATE_UUID: {TEMPLATE_UUID}\n")
  36. # 3. Read local image and encode it to base64
  37. local_image_path = os.path.join(script_dir, "..", "depth_map.png")
  38. print(f"3. Reading local reference image: {local_image_path}")
  39. if not os.path.exists(local_image_path):
  40. print(f"Error: {local_image_path} does not exist.")
  41. return
  42. with open(local_image_path, "rb") as f:
  43. image_bytes = f.read()
  44. b64_image = base64.b64encode(image_bytes).decode('utf-8')
  45. image_payload = f"data:image/png;base64,{b64_image}"
  46. # 4. Request ControlNet Image Generation (mode='depth')
  47. print("4. Executing controlnet generation (using Depth + found Checkpoint)...")
  48. try:
  49. gen_res = client.generate_advanced(
  50. mode="depth", # Changed to depth since it's a depth map
  51. prompt="A masterpiece, best quality, an easel on the left, a beautiful girl holding a clipboard on the right, highly detailed, photorealistic, cinematic lighting, 8k resolution",
  52. image=image_payload,
  53. base_model_uuid=version_uuid,
  54. width=1024,
  55. height=1024,
  56. steps=5,
  57. cfg_scale=1.5,
  58. control_nets=[{
  59. "mode": "depth",
  60. "image": image_payload,
  61. "weight": 0.4
  62. }]
  63. )
  64. print("\n=== Generation Task Success ===")
  65. print(json.dumps(gen_res, indent=2, ensure_ascii=False))
  66. except Exception as e:
  67. print(f"\n=== Generation Error ===")
  68. print(f"Error: {e}")
  69. if __name__ == '__main__':
  70. main()