base_plugin.py 737 B

123456789101112131415161718192021222324252627282930
  1. import numpy as np
  2. from loguru import logger
  3. from sorawm.iopaint.schema import RunPluginRequest
  4. class BasePlugin:
  5. name: str
  6. support_gen_image: bool = False
  7. support_gen_mask: bool = False
  8. def __init__(self):
  9. err_msg = self.check_dep()
  10. if err_msg:
  11. logger.error(err_msg)
  12. exit(-1)
  13. def gen_image(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
  14. # return RGBA np image or BGR np image
  15. ...
  16. def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
  17. # return GRAY or BGR np image, 255 means foreground, 0 means background
  18. ...
  19. def check_dep(self):
  20. ...
  21. def switch_model(self, new_model_name: str):
  22. ...