interactive_seg.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import hashlib
  2. from typing import List
  3. import numpy as np
  4. import torch
  5. from loguru import logger
  6. from sorawm.iopaint.helper import download_model
  7. from sorawm.iopaint.plugins.base_plugin import BasePlugin
  8. from sorawm.iopaint.plugins.segment_anything import SamPredictor, sam_model_registry
  9. from sorawm.iopaint.plugins.segment_anything2.build_sam import build_sam2
  10. from sorawm.iopaint.plugins.segment_anything2.sam2_image_predictor import (
  11. SAM2ImagePredictor,
  12. )
  13. from sorawm.iopaint.plugins.segment_anything.predictor_hq import SamHQPredictor
  14. from sorawm.iopaint.schema import RunPluginRequest
  15. # 从小到大
  16. SEGMENT_ANYTHING_MODELS = {
  17. "vit_b": {
  18. "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
  19. "md5": "01ec64d29a2fca3f0661936605ae66f8",
  20. },
  21. "vit_l": {
  22. "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
  23. "md5": "0b3195507c641ddb6910d2bb5adee89c",
  24. },
  25. "vit_h": {
  26. "url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
  27. "md5": "4b8939a88964f0f4ff5f5b2642c598a6",
  28. },
  29. "mobile_sam": {
  30. "url": "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt",
  31. "md5": "f3c0d8cda613564d499310dab6c812cd",
  32. },
  33. "sam_hq_vit_b": {
  34. "url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth",
  35. "md5": "c6b8953247bcfdc8bb8ef91e36a6cacc",
  36. },
  37. "sam_hq_vit_l": {
  38. "url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth",
  39. "md5": "08947267966e4264fb39523eccc33f86",
  40. },
  41. "sam_hq_vit_h": {
  42. "url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
  43. "md5": "3560f6b6a5a6edacd814a1325c39640a",
  44. },
  45. "sam2_tiny": {
  46. "url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
  47. "md5": "99eacccce4ada0b35153d4fd7af05297",
  48. },
  49. "sam2_small": {
  50. "url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
  51. "md5": "7f320dbeb497330a2472da5a16c7324d",
  52. },
  53. "sam2_base": {
  54. "url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
  55. "md5": "09dc5a3d7719f64aaea1d37341ef26f2",
  56. },
  57. "sam2_large": {
  58. "url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
  59. "md5": "08083462423be3260cd6a5eef94dc01c",
  60. },
  61. "sam2_1_tiny": {
  62. "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
  63. "md5": "6aa6761c9da74fbaa74b4c790a0a2007",
  64. },
  65. "sam2_1_small": {
  66. "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
  67. "md5": "51713b3d1994696d27f35f9c6de6f5ef",
  68. },
  69. "sam2_1_base": {
  70. "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
  71. "md5": "ec7bd7d23d280d5e3cfa45984c02eda5",
  72. },
  73. "sam2_1_large": {
  74. "url": "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
  75. "md5": "2b30654b6112c42a115563c638d238d9",
  76. },
  77. }
  78. class InteractiveSeg(BasePlugin):
  79. name = "InteractiveSeg"
  80. support_gen_mask = True
  81. def __init__(self, model_name, device):
  82. super().__init__()
  83. self.model_name = model_name
  84. self.device = device
  85. self._init_session(model_name)
  86. def _init_session(self, model_name: str):
  87. model_path = download_model(
  88. SEGMENT_ANYTHING_MODELS[model_name]["url"],
  89. SEGMENT_ANYTHING_MODELS[model_name]["md5"],
  90. )
  91. logger.info(f"SegmentAnything model path: {model_path}")
  92. if "sam_hq" in model_name:
  93. self.predictor = SamHQPredictor(
  94. sam_model_registry[model_name](checkpoint=model_path).to(self.device)
  95. )
  96. elif model_name.startswith("sam2"):
  97. sam2_model = build_sam2(
  98. model_name, ckpt_path=model_path, device=self.device
  99. )
  100. self.predictor = SAM2ImagePredictor(sam2_model)
  101. else:
  102. self.predictor = SamPredictor(
  103. sam_model_registry[model_name](checkpoint=model_path).to(self.device)
  104. )
  105. self.prev_img_md5 = None
  106. def switch_model(self, new_model_name):
  107. if self.model_name == new_model_name:
  108. return
  109. logger.info(
  110. f"Switching InteractiveSeg model from {self.model_name} to {new_model_name}"
  111. )
  112. self._init_session(new_model_name)
  113. self.model_name = new_model_name
  114. def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
  115. img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
  116. return self.forward(rgb_np_img, req.clicks, img_md5)
  117. @torch.inference_mode()
  118. def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
  119. input_point = []
  120. input_label = []
  121. for click in clicks:
  122. x = click[0]
  123. y = click[1]
  124. input_point.append([x, y])
  125. input_label.append(click[2])
  126. if img_md5 and img_md5 != self.prev_img_md5:
  127. self.prev_img_md5 = img_md5
  128. self.predictor.set_image(rgb_np_img)
  129. masks, _, _ = self.predictor.predict(
  130. point_coords=np.array(input_point),
  131. point_labels=np.array(input_label),
  132. multimask_output=False,
  133. )
  134. mask = masks[0].astype(np.uint8) * 255
  135. return mask