# test_brushnet.py (3.6 KB)
  1. import os
  2. from sorawm.iopaint.const import SD_BRUSHNET_CHOICES
  3. from sorawm.iopaint.tests.utils import assert_equal, check_device, get_config
  4. os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
  5. from pathlib import Path
  6. import pytest
  7. import torch
  8. from sorawm.iopaint.model_manager import ModelManager
  9. from sorawm.iopaint.schema import HDStrategy, PowerPaintTask, SDSampler
  10. current_dir = Path(__file__).parent.absolute().resolve()
  11. save_dir = current_dir / "result"
  12. save_dir.mkdir(exist_ok=True, parents=True)
  13. @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  14. @pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m_karras])
  15. def test_runway_brushnet(device, sampler):
  16. sd_steps = check_device(device)
  17. model = ModelManager(
  18. name="runwayml/stable-diffusion-v1-5",
  19. device=torch.device(device),
  20. disable_nsfw=True,
  21. sd_cpu_textencoder=False,
  22. )
  23. cfg = get_config(
  24. strategy=HDStrategy.ORIGINAL,
  25. prompt="face of a fox, sitting on a bench",
  26. sd_steps=sd_steps,
  27. sd_guidance_scale=7.5,
  28. enable_brushnet=True,
  29. brushnet_method=SD_BRUSHNET_CHOICES[0],
  30. )
  31. cfg.sd_sampler = sampler
  32. assert_equal(
  33. model,
  34. cfg,
  35. f"brushnet_random_mask_{device}.png",
  36. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  37. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  38. )
  39. @pytest.mark.parametrize("device", ["cuda", "mps"])
  40. @pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m])
  41. def test_runway_powerpaint_v2(device, sampler):
  42. sd_steps = check_device(device)
  43. model = ModelManager(
  44. name="runwayml/stable-diffusion-v1-5",
  45. device=torch.device(device),
  46. disable_nsfw=True,
  47. sd_cpu_textencoder=False,
  48. )
  49. tasks = {
  50. PowerPaintTask.text_guided: {
  51. "prompt": "face of a fox, sitting on a bench",
  52. "scale": 7.5,
  53. },
  54. PowerPaintTask.context_aware: {
  55. "prompt": "face of a fox, sitting on a bench",
  56. "scale": 7.5,
  57. },
  58. PowerPaintTask.shape_guided: {
  59. "prompt": "face of a fox, sitting on a bench",
  60. "scale": 7.5,
  61. },
  62. PowerPaintTask.object_remove: {
  63. "prompt": "",
  64. "scale": 12,
  65. },
  66. PowerPaintTask.outpainting: {
  67. "prompt": "",
  68. "scale": 7.5,
  69. },
  70. }
  71. for task, data in tasks.items():
  72. cfg = get_config(
  73. strategy=HDStrategy.ORIGINAL,
  74. prompt=data["prompt"],
  75. negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
  76. sd_steps=sd_steps,
  77. sd_guidance_scale=data["scale"],
  78. enable_powerpaint_v2=True,
  79. powerpaint_task=task,
  80. sd_sampler=sampler,
  81. sd_mask_blur=11,
  82. sd_seed=42,
  83. # sd_keep_unmasked_area=False
  84. )
  85. if task == PowerPaintTask.outpainting:
  86. cfg.use_extender = True
  87. cfg.extender_x = -128
  88. cfg.extender_y = -128
  89. cfg.extender_width = 768
  90. cfg.extender_height = 768
  91. assert_equal(
  92. model,
  93. cfg,
  94. f"powerpaint_v2_{device}_{task}.png",
  95. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  96. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  97. )