test_sdxl.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import os
  2. from sorawm.iopaint.tests.utils import check_device, current_dir
  3. os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
  4. import pytest
  5. import torch
  6. from sorawm.iopaint.model_manager import ModelManager
  7. from sorawm.iopaint.schema import HDStrategy, SDSampler
  8. from sorawm.iopaint.tests.test_model import assert_equal, get_config
  9. @pytest.mark.parametrize("device", ["cuda", "mps"])
  10. @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
  11. @pytest.mark.parametrize("sampler", [SDSampler.ddim])
  12. def test_sdxl(device, strategy, sampler):
  13. sd_steps = check_device(device)
  14. model = ModelManager(
  15. name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
  16. device=torch.device(device),
  17. disable_nsfw=True,
  18. sd_cpu_textencoder=False,
  19. )
  20. cfg = get_config(
  21. strategy=strategy,
  22. prompt="face of a fox, sitting on a bench",
  23. sd_steps=sd_steps,
  24. sd_strength=1.0,
  25. sd_guidance_scale=7.0,
  26. )
  27. cfg.sd_sampler = sampler
  28. assert_equal(
  29. model,
  30. cfg,
  31. f"sdxl_device_{device}.png",
  32. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  33. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  34. fx=2,
  35. fy=2,
  36. )
  37. @pytest.mark.parametrize("device", ["cuda", "cpu"])
  38. @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
  39. @pytest.mark.parametrize("sampler", [SDSampler.ddim])
  40. def test_sdxl_cpu_text_encoder(device, strategy, sampler):
  41. sd_steps = check_device(device)
  42. model = ModelManager(
  43. name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
  44. device=torch.device(device),
  45. disable_nsfw=True,
  46. sd_cpu_textencoder=True,
  47. )
  48. cfg = get_config(
  49. strategy=strategy,
  50. prompt="face of a fox, sitting on a bench",
  51. sd_steps=sd_steps,
  52. sd_strength=1.0,
  53. sd_guidance_scale=7.0,
  54. )
  55. cfg.sd_sampler = sampler
  56. assert_equal(
  57. model,
  58. cfg,
  59. f"sdxl_device_{device}.png",
  60. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  61. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  62. fx=2,
  63. fy=2,
  64. )
  65. @pytest.mark.parametrize("device", ["cuda", "mps"])
  66. @pytest.mark.parametrize(
  67. "rect",
  68. [
  69. [-128, -128, 1024, 1024],
  70. ],
  71. )
  72. def test_sdxl_outpainting(device, rect):
  73. sd_steps = check_device(device)
  74. model = ModelManager(
  75. name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
  76. device=torch.device(device),
  77. disable_nsfw=True,
  78. sd_cpu_textencoder=False,
  79. )
  80. cfg = get_config(
  81. strategy=HDStrategy.ORIGINAL,
  82. prompt="a dog sitting on a bench in the park",
  83. sd_steps=sd_steps,
  84. use_extender=True,
  85. extender_x=rect[0],
  86. extender_y=rect[1],
  87. extender_width=rect[2],
  88. extender_height=rect[3],
  89. sd_strength=1.0,
  90. sd_guidance_scale=8.0,
  91. sd_sampler=SDSampler.ddim,
  92. )
  93. assert_equal(
  94. model,
  95. cfg,
  96. f"sdxl_outpainting_dog_ddim_{'_'.join(map(str, rect))}_device_{device}.png",
  97. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  98. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  99. fx=1.5,
  100. fy=1.5,
  101. )