test_sd_model.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. import os
  2. from loguru import logger
  3. from sorawm.iopaint.tests.utils import assert_equal, check_device, get_config
  4. os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
  5. from pathlib import Path
  6. import pytest
  7. import torch
  8. from sorawm.iopaint.model_manager import ModelManager
  9. from sorawm.iopaint.schema import HDStrategy, SDSampler
  # Absolute, symlink-resolved directory of this test module; the reference
  # fixture images (overture-creations-*.png) are looked up relative to it.
  current_dir = Path(__file__).parent.resolve()
  # Output directory for test artifacts, created eagerly at import time.
  # NOTE(review): not referenced in this module; presumably used by the shared
  # test utilities — confirm before removing.
  save_dir = current_dir.joinpath("result")
  save_dir.mkdir(parents=True, exist_ok=True)
  @pytest.mark.parametrize("device", ["cuda", "mps"])
  def test_runway_sd_1_5_all_samplers(device):
      """Run the inpainting model once per ``SDSampler`` value on ``device``.

      Loads ``runwayml/stable-diffusion-inpainting`` and, for every sampler,
      builds an ``ORIGINAL``-strategy config and calls ``assert_equal`` with an
      output name derived from the device and sampler. On MPS, samplers listed
      in ``mps_unsupported_samplers`` are skipped with a warning.

      Args:
          device: torch device string. ``check_device`` returns the number of
              diffusion steps to use (it presumably also skips the test when the
              device is unavailable — confirm in ``sorawm.iopaint.tests.utils``).
      """
      sd_steps = check_device(device)
      model = ModelManager(
          name="runwayml/stable-diffusion-inpainting",
          device=torch.device(device),
          disable_nsfw=True,
          sd_cpu_textencoder=False,
      )

      # diffusers 0.25.0 still fails on MPS for these samplers (they need
      # float64, which MPS does not support); skip them until a fixed release.
      # Plain ``.value`` strings are stored because the loop iterates
      # ``member.value``, so membership is a str-to-str comparison that does not
      # depend on how the enum compares with strings.
      mps_unsupported_samplers = {
          SDSampler.dpm2_karras.value,
          SDSampler.dpm2_a_karras.value,
          SDSampler.lms_karras.value,
      }

      all_samplers = [member.value for member in SDSampler.__members__.values()]
      logger.info(f"Samplers under test: {all_samplers}")
      for sampler in all_samplers:
          logger.info(f"Testing sampler {sampler}")
          if device == "mps" and sampler in mps_unsupported_samplers:
              # Name the sampler actually being skipped so that dpm2_a_karras
              # and lms_karras skips are reported correctly.
              logger.warning(
                  f"skip {sampler} on mps, diffusers does not support it on mps. "
                  "TypeError: Cannot convert a MPS Tensor to float64 dtype as the "
                  "MPS framework doesn't support float64. Please use float32 instead."
              )
              continue
          cfg = get_config(
              strategy=HDStrategy.ORIGINAL,
              prompt="a fox sitting on a bench",
              sd_steps=sd_steps,
              sd_sampler=sampler,
          )
          name = f"device_{device}_{sampler}"
          assert_equal(
              model,
              cfg,
              f"runway_sd_{name}.png",
              img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
              mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
          )
  @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  @pytest.mark.parametrize("sampler", [SDSampler.lcm])
  def test_runway_sd_lcm_lora(device, sampler):
      """Inpaint with the LCM LoRA enabled, using a short, low-guidance schedule.

      ``check_device`` is called only for its device check. The step count it
      returns is deliberately ignored: this test always runs 5 steps with a
      guidance scale of 2.
      """
      check_device(device)
      inpainter = ModelManager(
          name="runwayml/stable-diffusion-inpainting",
          device=torch.device(device),
          disable_nsfw=True,
          sd_cpu_textencoder=False,
      )
      lcm_steps = 5
      lcm_config = get_config(
          strategy=HDStrategy.ORIGINAL,
          prompt="face of a fox, sitting on a bench",
          sd_steps=lcm_steps,
          sd_guidance_scale=2,
          sd_lcm_lora=True,
      )
      # The sampler is set on the built config rather than passed to get_config.
      lcm_config.sd_sampler = sampler

      image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
      mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
      output_name = f"runway_sd_1_5_lcm_lora_device_{device}.png"
      assert_equal(
          inpainter,
          lcm_config,
          output_name,
          img_p=image_path,
          mask_p=mask_path,
      )
  @pytest.mark.parametrize("device", ["cuda", "mps"])
  @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
  @pytest.mark.parametrize("sampler", [SDSampler.ddim])
  def test_runway_sd_sd_strength(device, strategy, sampler):
      """Inpaint with a reduced denoising strength (``sd_strength=0.8``).

      The strength value is embedded in the output file name, so that name and
      the config always agree.
      """
      steps = check_device(device)
      inpainter = ModelManager(
          name="runwayml/stable-diffusion-inpainting",
          device=torch.device(device),
          disable_nsfw=True,
          sd_cpu_textencoder=False,
      )
      strength = 0.8
      config = get_config(
          strategy=strategy,
          prompt="a fox sitting on a bench",
          sd_steps=steps,
          sd_strength=strength,
      )
      config.sd_sampler = sampler

      image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
      mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
      assert_equal(
          inpainter,
          config,
          f"runway_sd_strength_{strength}_device_{device}.png",
          img_p=image_path,
          mask_p=mask_path,
      )
  @pytest.mark.parametrize("device", ["cuda", "cpu"])
  @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
  @pytest.mark.parametrize("sampler", [SDSampler.ddim])
  def test_runway_sd_cpu_textencoder(device, strategy, sampler):
      """Inpaint with ``sd_cpu_textencoder=True`` on the given device.

      Unlike the neighbouring tests, the sampler is passed directly to
      ``get_config`` instead of being assigned afterwards.
      """
      steps = check_device(device)
      manager_options = dict(
          name="runwayml/stable-diffusion-inpainting",
          device=torch.device(device),
          disable_nsfw=True,
          sd_cpu_textencoder=True,
      )
      inpainter = ModelManager(**manager_options)
      config = get_config(
          strategy=strategy,
          prompt="a fox sitting on a bench",
          sd_steps=steps,
          sd_sampler=sampler,
      )
      output_name = f"runway_sd_device_{device}_cpu_textencoder.png"
      assert_equal(
          inpainter,
          config,
          output_name,
          img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
          mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
      )
  @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
  @pytest.mark.parametrize("sampler", [SDSampler.ddim])
  def test_runway_norm_sd_model(device, strategy, sampler):
      """Inpaint using ``runwayml/stable-diffusion-v1-5`` rather than the inpainting checkpoint.

      NOTE(review): presumably "norm" means the regular (non-inpainting) SD 1.5
      weights driven through the same inpainting pipeline — confirm in
      ``ModelManager``.
      """
      sd_steps = check_device(device)
      model = ModelManager(
          name="runwayml/stable-diffusion-v1-5",
          device=torch.device(device),
          disable_nsfw=True,
          sd_cpu_textencoder=False,
      )
      cfg = get_config(
          strategy=strategy, prompt="face of a fox, sitting on a bench", sd_steps=sd_steps
      )
      cfg.sd_sampler = sampler
      # Include the device once, following the "..._device_<device>.png" naming
      # used by the other tests in this module.
      assert_equal(
          model,
          cfg,
          f"runway_norm_sd_model_device_{device}.png",
          img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
          mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
      )
  @pytest.mark.parametrize("device", ["cuda"])
  @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
  @pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m])
  def test_runway_sd_1_5_cpu_offload(device, strategy, sampler):
      """Inpaint with ``cpu_offload=True`` (CUDA only).

      The output name embeds the capitalized strategy, the device, and the
      sampler, followed by a ``_cpu_offload`` suffix.
      """
      steps = check_device(device)
      inpainter = ModelManager(
          name="runwayml/stable-diffusion-inpainting",
          device=torch.device(device),
          disable_nsfw=True,
          sd_cpu_textencoder=False,
          cpu_offload=True,
      )
      config = get_config(strategy=strategy, prompt="a fox sitting on a bench", sd_steps=steps)
      config.sd_sampler = sampler

      output_name = (
          f"runway_sd_{strategy.capitalize()}_device_{device}_{sampler}_cpu_offload.png"
      )
      assert_equal(
          inpainter,
          config,
          output_name,
          img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
          mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
      )
  @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  @pytest.mark.parametrize("sampler", [SDSampler.ddim])
  @pytest.mark.parametrize(
      "name",
      [
          "sd-v1-5-inpainting.safetensors",
          "v1-5-pruned-emaonly.safetensors",
          "sd_xl_base_1.0.safetensors",
          "sd_xl_base_1.0_inpainting_0.1.safetensors",
      ],
  )
  def test_local_file_path(device, sampler, name):
      """Load a single-file checkpoint by file name and inpaint the reference image.

      Args:
          device: torch device string; ``check_device`` returns the step count.
          sampler: sampler assigned to the config after it is built.
          name: checkpoint file name passed unchanged to ``ModelManager``.
              Names containing ``"sd_xl"`` are treated as SDXL and get
              ``fx=fy=1.5`` (presumably an input resize factor in
              ``assert_equal`` — confirm).
      """
      sd_steps = check_device(device)
      model = ModelManager(
          name=name,
          device=torch.device(device),
          disable_nsfw=True,
          sd_cpu_textencoder=False,
          cpu_offload=False,
      )
      cfg = get_config(
          strategy=HDStrategy.ORIGINAL,
          prompt="a fox sitting on a bench",
          sd_steps=sd_steps,
      )
      cfg.sd_sampler = sampler

      # Detect SDXL from the checkpoint name itself instead of from the composed
      # output name, and keep the parametrized ``name`` meaning one thing.
      is_sdxl = "sd_xl" in name
      scale = 1.5 if is_sdxl else 1
      result_name = f"device_{device}_{sampler}_{name}"
      assert_equal(
          model,
          cfg,
          f"sd_local_model_{result_name}.png",
          img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
          mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
          fx=scale,
          fy=scale,
      )