remove_bg.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import os
  2. import cv2
  3. import numpy as np
  4. import torch
  5. from loguru import logger
  6. from torch.hub import get_dir
  7. from sorawm.iopaint.plugins.base_plugin import BasePlugin
  8. from sorawm.iopaint.schema import Device, RemoveBGModel, RunPluginRequest
  9. def _rmbg_remove(device, *args, **kwargs):
  10. from rembg import remove
  11. return remove(*args, **kwargs)
  12. class RemoveBG(BasePlugin):
  13. name = "RemoveBG"
  14. support_gen_mask = True
  15. support_gen_image = True
  16. def __init__(self, model_name, device):
  17. super().__init__()
  18. self.model_name = model_name
  19. self.device = device
  20. if model_name.startswith("birefnet"):
  21. import rembg
  22. if rembg.__version__ < "2.0.59":
  23. raise ValueError(
  24. "To use birefnet models, please upgrade rembg to >= 2.0.59. pip install -U rembg"
  25. )
  26. hub_dir = get_dir()
  27. model_dir = os.path.join(hub_dir, "checkpoints")
  28. os.environ["U2NET_HOME"] = model_dir
  29. self._init_session(model_name)
  30. def _init_session(self, model_name: str):
  31. self.device_warning()
  32. if model_name == RemoveBGModel.briaai_rmbg_1_4:
  33. from sorawm.iopaint.plugins.briarmbg import (
  34. briarmbg_process,
  35. create_briarmbg_session,
  36. )
  37. self.session = create_briarmbg_session().to(self.device)
  38. self.remove = briarmbg_process
  39. elif model_name == RemoveBGModel.briaai_rmbg_2_0:
  40. from sorawm.iopaint.plugins.briarmbg2 import (
  41. briarmbg2_process,
  42. create_briarmbg2_session,
  43. )
  44. self.session = create_briarmbg2_session().to(self.device)
  45. self.remove = briarmbg2_process
  46. else:
  47. from rembg import new_session
  48. self.session = new_session(model_name=model_name)
  49. self.remove = _rmbg_remove
  50. def switch_model(self, new_model_name):
  51. if self.model_name == new_model_name:
  52. return
  53. logger.info(
  54. f"Switching removebg model from {self.model_name} to {new_model_name}"
  55. )
  56. self._init_session(new_model_name)
  57. self.model_name = new_model_name
  58. @torch.inference_mode()
  59. def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
  60. bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
  61. # return BGRA image
  62. output = self.remove(self.device, bgr_np_img, session=self.session)
  63. return cv2.cvtColor(output, cv2.COLOR_BGRA2RGBA)
  64. @torch.inference_mode()
  65. def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
  66. bgr_np_img = cv2.cvtColor(rgb_np_img, cv2.COLOR_RGB2BGR)
  67. # return BGR image, 255 means foreground, 0 means background
  68. output = self.remove(
  69. self.device, bgr_np_img, session=self.session, only_mask=True
  70. )
  71. return output
  72. def check_dep(self):
  73. try:
  74. import rembg
  75. except ImportError as e:
  76. import traceback
  77. error_msg = traceback.format_exc()
  78. return f"Install rembg failed, Error details:\n{error_msg}"
  79. def device_warning(self):
  80. if self.device == Device.cuda and self.model_name not in [
  81. RemoveBGModel.briaai_rmbg_1_4,
  82. RemoveBGModel.briaai_rmbg_2_0,
  83. ]:
  84. logger.warning(
  85. f"remove_bg_device=cuda only supports briaai models({RemoveBGModel.briaai_rmbg_1_4.value}/{RemoveBGModel.briaai_rmbg_2_0.value})"
  86. )