| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- from functools import partial
- import torch
- from sorawm.iopaint.plugins.segment_anything.modeling.tiny_vit_sam import TinyViT
- from .modeling import (
- ImageEncoderViT,
- MaskDecoder,
- PromptEncoder,
- Sam,
- TwoWayTransformer,
- )
- from .modeling.image_encoder_hq import ImageEncoderViTHQ
- from .modeling.mask_decoder import MaskDecoderHQ
- from .modeling.sam_hq import SamHQ
def build_sam_vit_h(checkpoint=None):
    """Build SAM with the ViT-H image-encoder configuration.

    Args:
        checkpoint: Optional path to a state-dict file, forwarded unchanged
            to ``_build_sam``.

    Returns:
        Whatever ``_build_sam`` returns for these encoder settings.
    """
    vit_h_settings = dict(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        # Block indexes forwarded as ``global_attn_indexes`` to the encoder.
        encoder_global_attn_indexes=[7, 15, 23, 31],
    )
    return _build_sam(checkpoint=checkpoint, **vit_h_settings)
def build_sam_vit_l(checkpoint=None):
    """Build SAM with the ViT-L image-encoder configuration.

    Args:
        checkpoint: Optional path to a state-dict file, forwarded unchanged
            to ``_build_sam``.

    Returns:
        Whatever ``_build_sam`` returns for these encoder settings.
    """
    vit_l_settings = dict(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        # Block indexes forwarded as ``global_attn_indexes`` to the encoder.
        encoder_global_attn_indexes=[5, 11, 17, 23],
    )
    return _build_sam(checkpoint=checkpoint, **vit_l_settings)
def build_sam_vit_b(checkpoint=None):
    """Build SAM with the ViT-B image-encoder configuration.

    Args:
        checkpoint: Optional path to a state-dict file, forwarded unchanged
            to ``_build_sam``.

    Returns:
        Whatever ``_build_sam`` returns for these encoder settings.
    """
    vit_b_settings = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        # Block indexes forwarded as ``global_attn_indexes`` to the encoder.
        encoder_global_attn_indexes=[2, 5, 8, 11],
    )
    return _build_sam(checkpoint=checkpoint, **vit_b_settings)
def build_sam_vit_t(checkpoint=None):
    """Build MobileSAM: a ``Sam`` model whose image encoder is a ``TinyViT``.

    The prompt encoder, mask decoder and pixel normalization use the same
    settings as ``_build_sam``; only the image encoder differs.

    Args:
        checkpoint: Optional path to a state-dict file. When given, it is
            loaded with ``load_state_dict`` in strict mode.

    Returns:
        A ``Sam`` instance switched to eval mode.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    # Side length of the image-embedding grid handed to the prompt encoder
    # (1024 // 16 = 64).
    image_embedding_size = image_size // vit_patch_size
    mobile_sam = Sam(
        image_encoder=TinyViT(
            img_size=image_size,
            in_chans=3,
            num_classes=1000,
            embed_dims=[64, 128, 160, 320],
            depths=[2, 2, 6, 2],
            num_heads=[2, 4, 5, 10],
            window_sizes=[7, 7, 14, 7],
            mlp_ratio=4.0,
            drop_rate=0.0,
            drop_path_rate=0.0,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=0.8,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    mobile_sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            # The model above is built without any .to(device) call, and
            # load_state_dict copies tensors into the existing parameters.
            # Deserializing onto CPU therefore yields the same model while
            # also working for CUDA-saved checkpoints on CPU-only hosts.
            state_dict = torch.load(f, map_location="cpu")
        mobile_sam.load_state_dict(state_dict)
    return mobile_sam
def build_sam_vit_h_hq(checkpoint=None):
    """Build SAM-HQ with the ViT-H image-encoder configuration.

    Args:
        checkpoint: Optional path to a state-dict file, forwarded unchanged
            to ``_build_sam_hq``.

    Returns:
        Whatever ``_build_sam_hq`` returns for these encoder settings.
    """
    vit_h_settings = dict(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        # Block indexes forwarded as ``global_attn_indexes`` to the encoder.
        encoder_global_attn_indexes=[7, 15, 23, 31],
    )
    return _build_sam_hq(checkpoint=checkpoint, **vit_h_settings)
def build_sam_vit_l_hq(checkpoint=None):
    """Build SAM-HQ with the ViT-L image-encoder configuration.

    Args:
        checkpoint: Optional path to a state-dict file, forwarded unchanged
            to ``_build_sam_hq``.

    Returns:
        Whatever ``_build_sam_hq`` returns for these encoder settings.
    """
    vit_l_settings = dict(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        # Block indexes forwarded as ``global_attn_indexes`` to the encoder.
        encoder_global_attn_indexes=[5, 11, 17, 23],
    )
    return _build_sam_hq(checkpoint=checkpoint, **vit_l_settings)
def build_sam_vit_b_hq(checkpoint=None):
    """Build SAM-HQ with the ViT-B image-encoder configuration.

    Args:
        checkpoint: Optional path to a state-dict file, forwarded unchanged
            to ``_build_sam_hq``.

    Returns:
        Whatever ``_build_sam_hq`` returns for these encoder settings.
    """
    vit_b_settings = dict(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        # Block indexes forwarded as ``global_attn_indexes`` to the encoder.
        encoder_global_attn_indexes=[2, 5, 8, 11],
    )
    return _build_sam_hq(checkpoint=checkpoint, **vit_b_settings)
# Lookup table from a model-type name to the builder that constructs it.
# Every builder takes a single optional ``checkpoint`` path argument;
# "default" is an alias for the ViT-H builder.
sam_model_registry = dict(
    default=build_sam_vit_h,
    vit_h=build_sam_vit_h,
    vit_l=build_sam_vit_l,
    vit_b=build_sam_vit_b,
    sam_hq_vit_h=build_sam_vit_h_hq,
    sam_hq_vit_l=build_sam_vit_l_hq,
    sam_hq_vit_b=build_sam_vit_b_hq,
    mobile_sam=build_sam_vit_t,
)
def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """Assemble a ``Sam`` model around an ``ImageEncoderViT`` image encoder.

    Args:
        encoder_embed_dim: Passed as ``embed_dim`` to ``ImageEncoderViT``.
        encoder_depth: Passed as ``depth`` to ``ImageEncoderViT``.
        encoder_num_heads: Passed as ``num_heads`` to ``ImageEncoderViT``.
        encoder_global_attn_indexes: Passed as ``global_attn_indexes`` to
            ``ImageEncoderViT``.
        checkpoint: Optional path to a state-dict file. When given, it is
            loaded with ``load_state_dict`` in strict mode.

    Returns:
        A ``Sam`` instance switched to eval mode.
    """
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    # Side length of the image-embedding grid handed to the prompt encoder
    # (1024 // 16 = 64).
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            # The model above is built without any .to(device) call, and
            # load_state_dict copies tensors into the existing parameters.
            # Deserializing onto CPU therefore yields the same model while
            # also working for CUDA-saved checkpoints on CPU-only hosts.
            state_dict = torch.load(f, map_location="cpu")
        sam.load_state_dict(state_dict)
    return sam
def _build_sam_hq(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    """Assemble a ``SamHQ`` model around an ``ImageEncoderViTHQ`` encoder.

    Args:
        encoder_embed_dim: Passed as ``embed_dim`` to ``ImageEncoderViTHQ``
            and as ``vit_dim`` to ``MaskDecoderHQ``.
        encoder_depth: Passed as ``depth`` to ``ImageEncoderViTHQ``.
        encoder_num_heads: Passed as ``num_heads`` to ``ImageEncoderViTHQ``.
        encoder_global_attn_indexes: Passed as ``global_attn_indexes`` to
            ``ImageEncoderViTHQ``.
        checkpoint: Optional path to a state-dict file. When given, it is
            loaded non-strictly. Any missing/unexpected keys are logged as a
            warning. Afterwards, every parameter whose name does not contain
            one of the HQ marker substrings below has ``requires_grad``
            set to ``False``.

    Returns:
        A ``SamHQ`` instance switched to eval mode.
    """
    import logging

    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    # Side length of the image-embedding grid handed to the prompt encoder
    # (1024 // 16 = 64).
    image_embedding_size = image_size // vit_patch_size
    sam = SamHQ(
        image_encoder=ImageEncoderViTHQ(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoderHQ(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            vit_dim=encoder_embed_dim,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            # The model above is built without any .to(device) call, and
            # load_state_dict copies tensors into the existing parameters.
            # Deserializing straight onto CPU gives the same model without
            # first staging every tensor in GPU memory.
            state_dict = torch.load(f, map_location="cpu")
        # strict=False: tolerate key mismatches between the checkpoint and
        # SamHQ, but surface them instead of printing to stdout.
        info = sam.load_state_dict(state_dict, strict=False)
        if info.missing_keys or info.unexpected_keys:
            logging.getLogger(__name__).warning(
                "SAM-HQ checkpoint %s: missing keys=%s, unexpected keys=%s",
                checkpoint,
                info.missing_keys,
                info.unexpected_keys,
            )
        # Parameters whose names contain one of these substrings stay
        # trainable (presumably the HQ-specific modules); all others are
        # frozen.
        trainable_markers = (
            "hf_token",
            "hf_mlp",
            "compress_vit_feat",
            "embedding_encoder",
            "embedding_maskfeature",
        )
        for name, param in sam.named_parameters():
            if not any(marker in name for marker in trainable_markers):
                param.requires_grad = False
    return sam
|