test_controlnet.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
# Integration tests for ControlNet-enabled Stable Diffusion inpainting.
import os

from sorawm.iopaint.const import SD_CONTROLNET_CHOICES
from sorawm.iopaint.tests.utils import (
    assert_equal,
    check_device,
    current_dir,
    get_config,
)

# Must be set before any MPS work: lets PyTorch fall back to CPU for
# operators the MPS backend does not implement.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from pathlib import Path

import pytest
import torch

from sorawm.iopaint.model_manager import ModelManager
from sorawm.iopaint.schema import HDStrategy, SDSampler

# Base SD 1.5 inpainting checkpoint shared by the tests below.
model_name = "runwayml/stable-diffusion-inpainting"
  16. def convert_controlnet_method_name(name):
  17. return name.replace("/", "--")
  18. @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  19. @pytest.mark.parametrize("controlnet_method", [SD_CONTROLNET_CHOICES[0]])
  20. def test_runway_sd_1_5(device, controlnet_method):
  21. sd_steps = check_device(device)
  22. model = ModelManager(
  23. name=model_name,
  24. device=torch.device(device),
  25. disable_nsfw=True,
  26. sd_cpu_textencoder=device == "cuda",
  27. enable_controlnet=True,
  28. controlnet_method=controlnet_method,
  29. )
  30. cfg = get_config(
  31. prompt="a fox sitting on a bench",
  32. sd_steps=sd_steps,
  33. enable_controlnet=True,
  34. controlnet_conditioning_scale=0.5,
  35. controlnet_method=controlnet_method,
  36. )
  37. name = f"device_{device}"
  38. assert_equal(
  39. model,
  40. cfg,
  41. f"sd_controlnet_{convert_controlnet_method_name(controlnet_method)}_{name}.png",
  42. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  43. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  44. )
  45. @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  46. def test_controlnet_switch(device):
  47. sd_steps = check_device(device)
  48. model = ModelManager(
  49. name=model_name,
  50. device=torch.device(device),
  51. disable_nsfw=True,
  52. sd_cpu_textencoder=False,
  53. cpu_offload=True,
  54. enable_controlnet=True,
  55. controlnet_method="lllyasviel/control_v11p_sd15_canny",
  56. )
  57. cfg = get_config(
  58. prompt="a fox sitting on a bench",
  59. sd_steps=sd_steps,
  60. enable_controlnet=True,
  61. controlnet_method="lllyasviel/control_v11f1p_sd15_depth",
  62. )
  63. assert_equal(
  64. model,
  65. cfg,
  66. f"controlnet_switch_canny_to_depth_device_{device}.png",
  67. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  68. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  69. fx=1.2,
  70. )
  71. @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  72. @pytest.mark.parametrize(
  73. "local_file", ["sd-v1-5-inpainting.ckpt", "v1-5-pruned-emaonly.safetensors"]
  74. )
  75. def test_local_file_path(device, local_file):
  76. sd_steps = check_device(device)
  77. controlnet_kwargs = dict(
  78. enable_controlnet=True,
  79. controlnet_method=SD_CONTROLNET_CHOICES[0],
  80. )
  81. model = ModelManager(
  82. name=local_file,
  83. device=torch.device(device),
  84. disable_nsfw=True,
  85. sd_cpu_textencoder=False,
  86. cpu_offload=True,
  87. **controlnet_kwargs,
  88. )
  89. cfg = get_config(
  90. prompt="a fox sitting on a bench",
  91. sd_steps=sd_steps,
  92. **controlnet_kwargs,
  93. )
  94. name = f"device_{device}"
  95. assert_equal(
  96. model,
  97. cfg,
  98. f"{convert_controlnet_method_name(controlnet_kwargs['controlnet_method'])}_local_model_{name}.png",
  99. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  100. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  101. )