build_sam.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import torch

from sorawm.iopaint.plugins.segment_anything.modeling.tiny_vit_sam import TinyViT

from .modeling import (
    ImageEncoderViT,
    MaskDecoder,
    PromptEncoder,
    Sam,
    TwoWayTransformer,
)
from .modeling.image_encoder_hq import ImageEncoderViTHQ
from .modeling.mask_decoder import MaskDecoderHQ
from .modeling.sam_hq import SamHQ


def build_sam_vit_h(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )


def build_sam_vit_l(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )


def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


def build_sam_vit_t(checkpoint=None):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    mobile_sam = Sam(
        image_encoder=TinyViT(
            img_size=1024,
            in_chans=3,
            num_classes=1000,
            embed_dims=[64, 128, 160, 320],
            depths=[2, 2, 6, 2],
            num_heads=[2, 4, 5, 10],
            window_sizes=[7, 7, 14, 7],
            mlp_ratio=4.0,
            drop_rate=0.0,
            drop_path_rate=0.0,
            use_checkpoint=False,
            mbconv_expand_ratio=4.0,
            local_conv_size=3,
            layer_lr_decay=0.8,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )

    mobile_sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        mobile_sam.load_state_dict(state_dict)
    return mobile_sam


def build_sam_vit_h_hq(checkpoint=None):
    return _build_sam_hq(
        encoder_embed_dim=1280,
        encoder_depth=32,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[7, 15, 23, 31],
        checkpoint=checkpoint,
    )


def build_sam_vit_l_hq(checkpoint=None):
    return _build_sam_hq(
        encoder_embed_dim=1024,
        encoder_depth=24,
        encoder_num_heads=16,
        encoder_global_attn_indexes=[5, 11, 17, 23],
        checkpoint=checkpoint,
    )


def build_sam_vit_b_hq(checkpoint=None):
    return _build_sam_hq(
        encoder_embed_dim=768,
        encoder_depth=12,
        encoder_num_heads=12,
        encoder_global_attn_indexes=[2, 5, 8, 11],
        checkpoint=checkpoint,
    )


sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
    "sam_hq_vit_h": build_sam_vit_h_hq,
    "sam_hq_vit_l": build_sam_vit_l_hq,
    "sam_hq_vit_b": build_sam_vit_b_hq,
    "mobile_sam": build_sam_vit_t,
}
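
# Example usage (a minimal sketch, not part of the original module; the checkpoint
# path is a placeholder): the registry maps a model-type string to its builder, so
# callers construct a SAM variant by name and then move it to a device.
#
#   sam = sam_model_registry["vit_b"](checkpoint="path/to/sam_vit_b.pth")
#   sam.to("cuda" if torch.cuda.is_available() else "cpu")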


def _build_sam(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = Sam(
        image_encoder=ImageEncoderViT(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoder(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        sam.load_state_dict(state_dict)
    return sam


def _build_sam_hq(
    encoder_embed_dim,
    encoder_depth,
    encoder_num_heads,
    encoder_global_attn_indexes,
    checkpoint=None,
):
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    sam = SamHQ(
        image_encoder=ImageEncoderViTHQ(
            depth=encoder_depth,
            embed_dim=encoder_embed_dim,
            img_size=image_size,
            mlp_ratio=4,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
            num_heads=encoder_num_heads,
            patch_size=vit_patch_size,
            qkv_bias=True,
            use_rel_pos=True,
            global_attn_indexes=encoder_global_attn_indexes,
            window_size=14,
            out_chans=prompt_embed_dim,
        ),
        prompt_encoder=PromptEncoder(
            embed_dim=prompt_embed_dim,
            image_embedding_size=(image_embedding_size, image_embedding_size),
            input_image_size=(image_size, image_size),
            mask_in_chans=16,
        ),
        mask_decoder=MaskDecoderHQ(
            num_multimask_outputs=3,
            transformer=TwoWayTransformer(
                depth=2,
                embedding_dim=prompt_embed_dim,
                mlp_dim=2048,
                num_heads=8,
            ),
            transformer_dim=prompt_embed_dim,
            iou_head_depth=3,
            iou_head_hidden_dim=256,
            vit_dim=encoder_embed_dim,
        ),
        pixel_mean=[123.675, 116.28, 103.53],
        pixel_std=[58.395, 57.12, 57.375],
    )
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            state_dict = torch.load(f, map_location=device)
        # HQ checkpoints carry extra modules, so load non-strictly and report
        # any missing/unexpected keys.
        info = sam.load_state_dict(state_dict, strict=False)
        print(info)
    # Freeze everything except the HQ-specific modules (hf_token, hf_mlp,
    # compress_vit_feat, embedding_encoder, embedding_maskfeature), which
    # remain trainable.
    for n, p in sam.named_parameters():
        if (
            "hf_token" not in n
            and "hf_mlp" not in n
            and "compress_vit_feat" not in n
            and "embedding_encoder" not in n
            and "embedding_maskfeature" not in n
        ):
            p.requires_grad = False
    return sam
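
# Sketch (not part of the original file; the checkpoint path is a placeholder):
# verifying the freeze behaviour above. After _build_sam_hq runs, only the
# HQ-specific modules listed in the comment should keep requires_grad=True.
#
#   sam_hq = sam_model_registry["sam_hq_vit_b"](checkpoint="path/to/sam_hq_vit_b.pth")
#   trainable = [n for n, p in sam_hq.named_parameters() if p.requires_grad]
#   print(len(trainable), trainable[:5])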