test_low_mem.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import os
  2. from loguru import logger
  3. from sorawm.iopaint.tests.utils import (
  4. assert_equal,
  5. check_device,
  6. current_dir,
  7. get_config,
  8. )
  9. os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
  10. import pytest
  11. import torch
  12. from sorawm.iopaint.model_manager import ModelManager
  13. from sorawm.iopaint.schema import HDStrategy, SDSampler
  14. @pytest.mark.parametrize("device", ["cuda", "mps"])
  15. def test_runway_sd_1_5_low_mem(device):
  16. sd_steps = check_device(device)
  17. model = ModelManager(
  18. name="runwayml/stable-diffusion-inpainting",
  19. device=torch.device(device),
  20. disable_nsfw=True,
  21. sd_cpu_textencoder=False,
  22. low_mem=True,
  23. )
  24. all_samplers = [member.value for member in SDSampler.__members__.values()]
  25. print(all_samplers)
  26. cfg = get_config(
  27. strategy=HDStrategy.ORIGINAL,
  28. prompt="a fox sitting on a bench",
  29. sd_steps=sd_steps,
  30. sd_sampler=SDSampler.ddim,
  31. )
  32. name = f"device_{device}"
  33. assert_equal(
  34. model,
  35. cfg,
  36. f"runway_sd_{name}_low_mem.png",
  37. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  38. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  39. )
  40. @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  41. @pytest.mark.parametrize("sampler", [SDSampler.lcm])
  42. def test_runway_sd_lcm_lora_low_mem(device, sampler):
  43. check_device(device)
  44. sd_steps = 5
  45. model = ModelManager(
  46. name="runwayml/stable-diffusion-inpainting",
  47. device=torch.device(device),
  48. disable_nsfw=True,
  49. sd_cpu_textencoder=False,
  50. low_mem=True,
  51. )
  52. cfg = get_config(
  53. strategy=HDStrategy.ORIGINAL,
  54. prompt="face of a fox, sitting on a bench",
  55. sd_steps=sd_steps,
  56. sd_guidance_scale=2,
  57. sd_lcm_lora=True,
  58. )
  59. cfg.sd_sampler = sampler
  60. assert_equal(
  61. model,
  62. cfg,
  63. f"runway_sd_1_5_lcm_lora_device_{device}_low_mem.png",
  64. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  65. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  66. )
  67. @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  68. @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
  69. @pytest.mark.parametrize("sampler", [SDSampler.ddim])
  70. def test_runway_norm_sd_model(device, strategy, sampler):
  71. sd_steps = check_device(device)
  72. model = ModelManager(
  73. name="runwayml/stable-diffusion-v1-5",
  74. device=torch.device(device),
  75. disable_nsfw=True,
  76. sd_cpu_textencoder=False,
  77. low_mem=True,
  78. )
  79. cfg = get_config(
  80. strategy=strategy, prompt="face of a fox, sitting on a bench", sd_steps=sd_steps
  81. )
  82. cfg.sd_sampler = sampler
  83. assert_equal(
  84. model,
  85. cfg,
  86. f"runway_{device}_norm_sd_model_device_{device}_low_mem.png",
  87. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  88. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  89. )