# test_model_switch.py (1.8 KB)

  1. import os
  2. from sorawm.iopaint.schema import InpaintRequest
  3. os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
  4. import torch
  5. from sorawm.iopaint.model_manager import ModelManager
  6. def test_model_switch():
  7. model = ModelManager(
  8. name="runwayml/stable-diffusion-inpainting",
  9. enable_controlnet=True,
  10. controlnet_method="lllyasviel/control_v11p_sd15_canny",
  11. device=torch.device("mps"),
  12. disable_nsfw=True,
  13. sd_cpu_textencoder=True,
  14. cpu_offload=False,
  15. )
  16. model.switch("lama")
  17. def test_controlnet_switch_onoff(caplog):
  18. name = "runwayml/stable-diffusion-inpainting"
  19. model = ModelManager(
  20. name=name,
  21. enable_controlnet=True,
  22. controlnet_method="lllyasviel/control_v11p_sd15_canny",
  23. device=torch.device("mps"),
  24. disable_nsfw=True,
  25. sd_cpu_textencoder=True,
  26. cpu_offload=False,
  27. )
  28. model.switch_controlnet_method(
  29. InpaintRequest(
  30. name=name,
  31. enable_controlnet=False,
  32. )
  33. )
  34. assert "Disable controlnet" in caplog.text
  35. def test_switch_controlnet_method(caplog):
  36. name = "runwayml/stable-diffusion-inpainting"
  37. old_method = "lllyasviel/control_v11p_sd15_canny"
  38. new_method = "lllyasviel/control_v11p_sd15_openpose"
  39. model = ModelManager(
  40. name=name,
  41. enable_controlnet=True,
  42. controlnet_method=old_method,
  43. device=torch.device("mps"),
  44. disable_nsfw=True,
  45. sd_cpu_textencoder=True,
  46. cpu_offload=False,
  47. )
  48. model.switch_controlnet_method(
  49. InpaintRequest(
  50. name=name,
  51. enable_controlnet=True,
  52. controlnet_method=new_method,
  53. )
  54. )
  55. assert f"Switch Controlnet method from {old_method} to {new_method}" in caplog.text