test_model.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import pytest
  2. import torch
  3. from sorawm.iopaint.model_manager import ModelManager
  4. from sorawm.iopaint.schema import HDStrategy, LDMSampler
  5. from sorawm.iopaint.tests.utils import (
  6. assert_equal,
  7. check_device,
  8. current_dir,
  9. get_config,
  10. )
  11. @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  12. @pytest.mark.parametrize(
  13. "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
  14. )
  15. def test_lama(device, strategy):
  16. check_device(device)
  17. model = ModelManager(name="lama", device=device)
  18. assert_equal(
  19. model,
  20. get_config(strategy=strategy),
  21. f"lama_{strategy[0].upper() + strategy[1:]}_result.png",
  22. )
  23. fx = 1.3
  24. assert_equal(
  25. model,
  26. get_config(strategy=strategy),
  27. f"lama_{strategy[0].upper() + strategy[1:]}_fx_{fx}_result.png",
  28. fx=1.3,
  29. )
  30. @pytest.mark.parametrize("device", ["cuda", "cpu"])
  31. @pytest.mark.parametrize(
  32. "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
  33. )
  34. @pytest.mark.parametrize("ldm_sampler", [LDMSampler.ddim, LDMSampler.plms])
  35. def test_ldm(device, strategy, ldm_sampler):
  36. check_device(device)
  37. model = ModelManager(name="ldm", device=device)
  38. cfg = get_config(strategy=strategy, ldm_sampler=ldm_sampler)
  39. assert_equal(
  40. model, cfg, f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_result.png"
  41. )
  42. fx = 1.3
  43. assert_equal(
  44. model,
  45. cfg,
  46. f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_fx_{fx}_result.png",
  47. fx=fx,
  48. )
  49. @pytest.mark.parametrize("device", ["cuda", "cpu"])
  50. @pytest.mark.parametrize(
  51. "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
  52. )
  53. @pytest.mark.parametrize("zits_wireframe", [False, True])
  54. def test_zits(device, strategy, zits_wireframe):
  55. check_device(device)
  56. model = ModelManager(name="zits", device=device)
  57. cfg = get_config(strategy=strategy, zits_wireframe=zits_wireframe)
  58. assert_equal(
  59. model,
  60. cfg,
  61. f"zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_result.png",
  62. )
  63. fx = 1.3
  64. assert_equal(
  65. model,
  66. cfg,
  67. f"zits_{strategy.capitalize()}_wireframe_{zits_wireframe}_fx_{fx}_result.png",
  68. fx=fx,
  69. )
  70. @pytest.mark.parametrize("device", ["cuda", "cpu"])
  71. @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
  72. @pytest.mark.parametrize("no_half", [True, False])
  73. def test_mat(device, strategy, no_half):
  74. check_device(device)
  75. model = ModelManager(name="mat", device=device, no_half=no_half)
  76. cfg = get_config(strategy=strategy)
  77. assert_equal(
  78. model,
  79. cfg,
  80. f"mat_{strategy.capitalize()}_result.png",
  81. )
  82. @pytest.mark.parametrize("device", ["cuda", "cpu"])
  83. @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
  84. def test_fcf(device, strategy):
  85. check_device(device)
  86. model = ModelManager(name="fcf", device=device)
  87. cfg = get_config(strategy=strategy)
  88. assert_equal(model, cfg, f"fcf_{strategy.capitalize()}_result.png", fx=2, fy=2)
  89. assert_equal(model, cfg, f"fcf_{strategy.capitalize()}_result.png", fx=3.8, fy=2)
  90. @pytest.mark.parametrize(
  91. "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
  92. )
  93. @pytest.mark.parametrize("cv2_flag", ["INPAINT_NS", "INPAINT_TELEA"])
  94. @pytest.mark.parametrize("cv2_radius", [3, 15])
  95. def test_cv2(strategy, cv2_flag, cv2_radius):
  96. model = ModelManager(
  97. name="cv2",
  98. device=torch.device("cpu"),
  99. )
  100. cfg = get_config(strategy=strategy, cv2_flag=cv2_flag, cv2_radius=cv2_radius)
  101. assert_equal(
  102. model,
  103. cfg,
  104. f"cv2_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png",
  105. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  106. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  107. )
  108. @pytest.mark.parametrize("device", ["cuda", "cpu"])
  109. @pytest.mark.parametrize(
  110. "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
  111. )
  112. def test_manga(device, strategy):
  113. check_device(device)
  114. model = ModelManager(
  115. name="manga",
  116. device=torch.device(device),
  117. )
  118. cfg = get_config(strategy=strategy)
  119. assert_equal(
  120. model,
  121. cfg,
  122. f"manga_{strategy.capitalize()}.png",
  123. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  124. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  125. )
  126. @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
  127. @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
  128. def test_mi_gan(device, strategy):
  129. check_device(device)
  130. model = ModelManager(
  131. name="migan",
  132. device=torch.device(device),
  133. )
  134. cfg = get_config(strategy=strategy)
  135. assert_equal(
  136. model,
  137. cfg,
  138. f"migan_device_{device}.png",
  139. img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
  140. mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
  141. fx=1.5,
  142. fy=1.7,
  143. )