test_instruct_pix2pix.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. from pathlib import Path
  2. import pytest
  3. import torch
  4. from sorawm.iopaint.model_manager import ModelManager
  5. from sorawm.iopaint.schema import HDStrategy
  6. from sorawm.iopaint.tests.utils import (
  7. assert_equal,
  8. check_device,
  9. current_dir,
  10. get_config,
  11. )
  12. model_name = "timbrooks/instruct-pix2pix"
  13. @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  14. @pytest.mark.parametrize("disable_nsfw", [True, False])
  15. @pytest.mark.parametrize("cpu_offload", [False, True])
  16. def test_instruct_pix2pix(device, disable_nsfw, cpu_offload):
  17. sd_steps = check_device(device)
  18. model = ModelManager(
  19. name=model_name,
  20. device=torch.device(device),
  21. disable_nsfw=disable_nsfw,
  22. sd_cpu_textencoder=False,
  23. cpu_offload=cpu_offload,
  24. )
  25. cfg = get_config(
  26. strategy=HDStrategy.ORIGINAL,
  27. prompt="What if it were snowing?",
  28. sd_steps=sd_steps,
  29. )
  30. name = f"device_{device}_disnsfw_{disable_nsfw}_cpu_offload_{cpu_offload}"
  31. assert_equal(
  32. model,
  33. cfg,
  34. f"instruct_pix2pix_{name}.png",
  35. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  36. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  37. fx=1.3,
  38. )