| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- import hashlib
- from typing import List
- import numpy as np
- import torch
- from loguru import logger
- from sorawm.iopaint.helper import download_model
- from sorawm.iopaint.plugins.base_plugin import BasePlugin
- from sorawm.iopaint.plugins.segment_anything import SamPredictor, sam_model_registry
- from sorawm.iopaint.plugins.segment_anything2.build_sam import build_sam2
- from sorawm.iopaint.plugins.segment_anything2.sam2_image_predictor import (
- SAM2ImagePredictor,
- )
- from sorawm.iopaint.plugins.segment_anything.predictor_hq import SamHQPredictor
- from sorawm.iopaint.schema import RunPluginRequest
# Checkpoints selectable by InteractiveSeg: model name -> {"url": download
# location, "md5": checksum handed to download_model}.
# Ordered from small to large within each model family.
SEGMENT_ANYTHING_MODELS = {
    name: {"url": url, "md5": md5}
    for name, url, md5 in (
        # Original SAM (ViT backbones).
        (
            "vit_b",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
            "01ec64d29a2fca3f0661936605ae66f8",
        ),
        (
            "vit_l",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
            "0b3195507c641ddb6910d2bb5adee89c",
        ),
        (
            "vit_h",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
            "4b8939a88964f0f4ff5f5b2642c598a6",
        ),
        # MobileSAM.
        (
            "mobile_sam",
            "https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt",
            "f3c0d8cda613564d499310dab6c812cd",
        ),
        # SAM-HQ.
        (
            "sam_hq_vit_b",
            "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_b.pth",
            "c6b8953247bcfdc8bb8ef91e36a6cacc",
        ),
        (
            "sam_hq_vit_l",
            "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_l.pth",
            "08947267966e4264fb39523eccc33f86",
        ),
        (
            "sam_hq_vit_h",
            "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth",
            "3560f6b6a5a6edacd814a1325c39640a",
        ),
        # SAM 2 (072824 release).
        (
            "sam2_tiny",
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
            "99eacccce4ada0b35153d4fd7af05297",
        ),
        (
            "sam2_small",
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
            "7f320dbeb497330a2472da5a16c7324d",
        ),
        (
            "sam2_base",
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
            "09dc5a3d7719f64aaea1d37341ef26f2",
        ),
        (
            "sam2_large",
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
            "08083462423be3260cd6a5eef94dc01c",
        ),
        # SAM 2.1 (092824 release).
        (
            "sam2_1_tiny",
            "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt",
            "6aa6761c9da74fbaa74b4c790a0a2007",
        ),
        (
            "sam2_1_small",
            "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt",
            "51713b3d1994696d27f35f9c6de6f5ef",
        ),
        (
            "sam2_1_base",
            "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt",
            "ec7bd7d23d280d5e3cfa45984c02eda5",
        ),
        (
            "sam2_1_large",
            "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt",
            "2b30654b6112c42a115563c638d238d9",
        ),
    )
}
class InteractiveSeg(BasePlugin):
    """Plugin that turns user clicks into a binary mask using a SAM-family model.

    The predictor implementation is picked from the model name, which must be
    a key of ``SEGMENT_ANYTHING_MODELS``:

    * names containing ``"sam_hq"`` -> ``SamHQPredictor``
    * names starting with ``"sam2"`` -> ``SAM2ImagePredictor``
    * everything else -> ``SamPredictor``

    The image embedding is cached per image hash, so repeated clicks on the
    same image skip ``set_image``.
    """

    name = "InteractiveSeg"
    support_gen_mask = True

    def __init__(self, model_name, device):
        """Download (if needed) and load ``model_name`` onto ``device``.

        Args:
            model_name: Key of ``SEGMENT_ANYTHING_MODELS``.
            device: Device the model is moved to (``.to(device)``) or built on
                (``build_sam2(..., device=device)``).

        Raises:
            KeyError: If ``model_name`` is not a supported model.
        """
        super().__init__()
        self.model_name = model_name
        self.device = device
        self._init_session(model_name)

    def _init_session(self, model_name: str):
        """(Re)build ``self.predictor`` for ``model_name`` and reset the image cache.

        ``self.predictor`` is only reassigned once the new predictor has been
        constructed, so a failure leaves the previous predictor in place.

        Raises:
            KeyError: If ``model_name`` is not in ``SEGMENT_ANYTHING_MODELS``.
        """
        try:
            model_info = SEGMENT_ANYTHING_MODELS[model_name]
        except KeyError:
            raise KeyError(
                f"Unknown InteractiveSeg model {model_name!r}; "
                f"supported models: {', '.join(SEGMENT_ANYTHING_MODELS)}"
            ) from None

        model_path = download_model(model_info["url"], model_info["md5"])
        logger.info(f"SegmentAnything model path: {model_path}")

        # "sam_hq" is checked first, matching the original dispatch order.
        if "sam_hq" in model_name:
            self.predictor = SamHQPredictor(
                sam_model_registry[model_name](checkpoint=model_path).to(self.device)
            )
        elif model_name.startswith("sam2"):
            sam2_model = build_sam2(
                model_name, ckpt_path=model_path, device=self.device
            )
            self.predictor = SAM2ImagePredictor(sam2_model)
        else:
            self.predictor = SamPredictor(
                sam_model_registry[model_name](checkpoint=model_path).to(self.device)
            )
        # The fresh predictor has no image embedded yet; forget the cached hash
        # so the next forward() calls set_image().
        self.prev_img_md5 = None

    def switch_model(self, new_model_name):
        """Load ``new_model_name`` in place of the current model.

        Does nothing if ``new_model_name`` is already the active model.

        Raises:
            KeyError: If ``new_model_name`` is not a supported model.
        """
        if self.model_name == new_model_name:
            return
        logger.info(
            f"Switching InteractiveSeg model from {self.model_name} to {new_model_name}"
        )
        self._init_session(new_model_name)
        # Record the new name only after loading succeeded, so a failed switch
        # leaves model_name describing the predictor that is still in use.
        self.model_name = new_model_name

    def gen_mask(self, rgb_np_img, req: RunPluginRequest) -> np.ndarray:
        """Generate a mask for ``req.clicks`` on ``rgb_np_img``.

        ``req.image`` (a string; presumably the encoded request image — confirm
        against ``RunPluginRequest``) is hashed only to serve as the cache key
        for the image embedding.
        """
        img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
        return self.forward(rgb_np_img, req.clicks, img_md5)

    @torch.inference_mode()
    def forward(self, rgb_np_img, clicks: List[List], img_md5: str) -> np.ndarray:
        """Predict a single mask from point prompts.

        Args:
            rgb_np_img: Image handed to ``predictor.set_image`` (RGB numpy array
                per the name; assumed HxWx3 — TODO confirm with predictor).
            clicks: Sequence of ``[x, y, label]`` entries. ``label`` is passed
                through as the point label (presumably 1 = include,
                0 = exclude, per SAM convention — confirm).
            img_md5: Cache key for ``rgb_np_img``. When it matches the last
                embedded image, ``set_image`` is skipped. When falsy, caching
                is disabled and the image is always re-embedded.

        Returns:
            ``uint8`` array equal to the first predicted mask times 255
            (0/255 assuming the predictor returns boolean masks).
        """
        input_point = [[click[0], click[1]] for click in clicks]
        input_label = [click[2] for click in clicks]

        if not img_md5:
            # Without a key we cannot tell whether the image changed, so
            # always embed it, and clear the cache so a later keyed call
            # re-embeds too.
            self.prev_img_md5 = None
            self.predictor.set_image(rgb_np_img)
        elif img_md5 != self.prev_img_md5:
            # Update the cache only after set_image succeeds; otherwise a
            # failed embed would be mistaken for a cached one next time.
            self.predictor.set_image(rgb_np_img)
            self.prev_img_md5 = img_md5

        masks, _, _ = self.predictor.predict(
            point_coords=np.array(input_point),
            point_labels=np.array(input_label),
            multimask_output=False,
        )
        return masks[0].astype(np.uint8) * 255
|