__init__.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from typing import Dict
  2. from loguru import logger
  3. from ..schema import Device, InteractiveSegModel, RealESRGANModel
  4. from .anime_seg import AnimeSeg
  5. from .gfpgan_plugin import GFPGANPlugin
  6. from .interactive_seg import InteractiveSeg
  7. from .realesrgan import RealESRGANUpscaler
  8. from .remove_bg import RemoveBG
  9. from .restoreformer import RestoreFormerPlugin
  10. def build_plugins(
  11. enable_interactive_seg: bool,
  12. interactive_seg_model: InteractiveSegModel,
  13. interactive_seg_device: Device,
  14. enable_remove_bg: bool,
  15. remove_bg_device: Device,
  16. remove_bg_model: str,
  17. enable_anime_seg: bool,
  18. enable_realesrgan: bool,
  19. realesrgan_device: Device,
  20. realesrgan_model: RealESRGANModel,
  21. enable_gfpgan: bool,
  22. gfpgan_device: Device,
  23. enable_restoreformer: bool,
  24. restoreformer_device: Device,
  25. no_half: bool,
  26. ) -> Dict:
  27. plugins = {}
  28. if enable_interactive_seg:
  29. logger.info(f"Initialize {InteractiveSeg.name} plugin")
  30. plugins[InteractiveSeg.name] = InteractiveSeg(
  31. interactive_seg_model, interactive_seg_device
  32. )
  33. if enable_remove_bg:
  34. logger.info(f"Initialize {RemoveBG.name} plugin")
  35. plugins[RemoveBG.name] = RemoveBG(remove_bg_model, remove_bg_device)
  36. if enable_anime_seg:
  37. logger.info(f"Initialize {AnimeSeg.name} plugin")
  38. plugins[AnimeSeg.name] = AnimeSeg()
  39. if enable_realesrgan:
  40. logger.info(
  41. f"Initialize {RealESRGANUpscaler.name} plugin: {realesrgan_model}, {realesrgan_device}"
  42. )
  43. plugins[RealESRGANUpscaler.name] = RealESRGANUpscaler(
  44. realesrgan_model,
  45. realesrgan_device,
  46. no_half=no_half,
  47. )
  48. if enable_gfpgan:
  49. logger.info(f"Initialize {GFPGANPlugin.name} plugin")
  50. if enable_realesrgan:
  51. logger.info("Use realesrgan as GFPGAN background upscaler")
  52. else:
  53. logger.info(
  54. f"GFPGAN no background upscaler, use --enable-realesrgan to enable it"
  55. )
  56. plugins[GFPGANPlugin.name] = GFPGANPlugin(
  57. gfpgan_device,
  58. upscaler=plugins.get(RealESRGANUpscaler.name, None),
  59. )
  60. if enable_restoreformer:
  61. logger.info(f"Initialize {RestoreFormerPlugin.name} plugin")
  62. plugins[RestoreFormerPlugin.name] = RestoreFormerPlugin(
  63. restoreformer_device,
  64. upscaler=plugins.get(RealESRGANUpscaler.name, None),
  65. )
  66. return plugins