mask_decoder.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Tuple, Type

import torch
from torch import nn
from torch.nn import functional as F

from .common import LayerNorm2d


class MaskDecoder(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )
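
    # Note (added for clarity, not part of the upstream interface): the decoder
    # prepends its learned output tokens to the prompt tokens before running the
    # transformer, so the token sequence is
    #   [iou_token, mask_token_0 .. mask_token_{num_mask_tokens-1}, sparse prompt tokens...],
    # and num_mask_tokens = num_multimask_outputs + 1 (one token for the
    # single-mask output plus one per disambiguation mask).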

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            mask_slice = slice(1, None)
        else:
            mask_slice = slice(0, 1)
        masks = masks[:, mask_slice, :, :]
        iou_pred = iou_pred[:, mask_slice]

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat(
            [self.iou_token.weight, self.mask_tokens.weight], dim=0
        )
        output_tokens = output_tokens.unsqueeze(0).expand(
            sparse_prompt_embeddings.size(0), -1, -1
        )
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding = self.output_upscaling(src)
        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            hyper_in_list.append(
                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
            )
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding.shape
        masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred
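

# Usage sketch for MaskDecoder (illustrative, not part of the original file).
# It assumes the prompt-encoder / two-way-transformer setup from the SAM repo
# and hypothetical tensor names; shapes follow the 1024x1024-input configuration:
#
#   decoder = MaskDecoder(
#       transformer_dim=256,
#       transformer=TwoWayTransformer(
#           depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8
#       ),
#   )
#   masks, iou_pred = decoder(
#       image_embeddings=image_emb,           # (1, 256, 64, 64) from the image encoder
#       image_pe=image_pe,                    # (1, 256, 64, 64) positional encoding
#       sparse_prompt_embeddings=sparse_emb,  # (B, N_prompt_tokens, 256)
#       dense_prompt_embeddings=dense_emb,    # (B, 256, 64, 64)
#       multimask_output=True,
#   )
#   # masks: (B, 3, 256, 256) low-res mask logits; iou_pred: (B, 3) predicted quality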


# https://github.com/SysCV/sam-hq/blob/main/segment_anything/modeling/mask_decoder_hq.py#L17
class MaskDecoderHQ(nn.Module):
    def __init__(
        self,
        *,
        transformer_dim: int,
        transformer: nn.Module,
        num_multimask_outputs: int = 3,
        activation: Type[nn.Module] = nn.GELU,
        iou_head_depth: int = 3,
        iou_head_hidden_dim: int = 256,
        vit_dim: int = 1024,
    ) -> None:
        """
        Predicts masks given an image and prompt embeddings, using a
        transformer architecture.

        Arguments:
          transformer_dim (int): the channel dimension of the transformer
          transformer (nn.Module): the transformer used to predict masks
          num_multimask_outputs (int): the number of masks to predict
            when disambiguating masks
          activation (nn.Module): the type of activation to use when
            upscaling masks
          iou_head_depth (int): the depth of the MLP used to predict
            mask quality
          iou_head_hidden_dim (int): the hidden dimension of the MLP
            used to predict mask quality
          vit_dim (int): the channel dimension of the intermediate ViT
            features used to build the HQ feature
        """
        super().__init__()
        self.transformer_dim = transformer_dim
        self.transformer = transformer

        self.num_multimask_outputs = num_multimask_outputs

        self.iou_token = nn.Embedding(1, transformer_dim)
        self.num_mask_tokens = num_multimask_outputs + 1
        self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

        self.output_upscaling = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            activation(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
            activation(),
        )
        self.output_hypernetworks_mlps = nn.ModuleList(
            [
                MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
                for i in range(self.num_mask_tokens)
            ]
        )

        self.iou_prediction_head = MLP(
            transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth
        )

        # HQ-SAM parameters
        self.hf_token = nn.Embedding(1, transformer_dim)  # HQ-Output-Token
        self.hf_mlp = MLP(
            transformer_dim, transformer_dim, transformer_dim // 8, 3
        )  # corresponding new MLP layer for HQ-Output-Token
        self.num_mask_tokens = self.num_mask_tokens + 1

        # three conv fusion layers for obtaining HQ-Feature
        self.compress_vit_feat = nn.Sequential(
            nn.ConvTranspose2d(vit_dim, transformer_dim, kernel_size=2, stride=2),
            LayerNorm2d(transformer_dim),
            nn.GELU(),
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 8, kernel_size=2, stride=2
            ),
        )
        self.embedding_encoder = nn.Sequential(
            nn.ConvTranspose2d(
                transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
            ),
            LayerNorm2d(transformer_dim // 4),
            nn.GELU(),
            nn.ConvTranspose2d(
                transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
            ),
        )
        self.embedding_maskfeature = nn.Sequential(
            nn.Conv2d(transformer_dim // 8, transformer_dim // 4, 3, 1, 1),
            LayerNorm2d(transformer_dim // 4),
            nn.GELU(),
            nn.Conv2d(transformer_dim // 4, transformer_dim // 8, 3, 1, 1),
        )
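
    # Added note (not in the upstream file): `hq_features` is built in `forward`
    # by fusing two 4x-upsampled maps of transformer_dim // 8 channels,
    # embedding_encoder(image_embeddings) and compress_vit_feat(vit_features),
    # where vit_features is an early intermediate ViT feature map. The default
    # vit_dim=1024 matches a ViT-L image encoder; adjust it for other backbones.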

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        multimask_output: bool,
        hq_token_only: bool,
        interm_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict masks given image and prompt embeddings.

        Arguments:
          image_embeddings (torch.Tensor): the embeddings from the ViT image encoder
          image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
          sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
          dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
          multimask_output (bool): Whether to return multiple masks or a single
            mask.
          hq_token_only (bool): Whether to return only the HQ mask or the sum of
            the SAM mask and the HQ mask.
          interm_embeddings (torch.Tensor): intermediate ViT features; the first
            entry (after the 1st global attention block) is used for the HQ feature.

        Returns:
          torch.Tensor: batched predicted masks
          torch.Tensor: batched predictions of mask quality
        """
        vit_features = interm_embeddings[0].permute(
            0, 3, 1, 2
        )  # early-layer ViT feature, after 1st global attention block in ViT
        hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(
            vit_features
        )

        masks, iou_pred = self.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_prompt_embeddings,
            dense_prompt_embeddings=dense_prompt_embeddings,
            hq_features=hq_features,
        )

        # Select the correct mask or masks for output
        if multimask_output:
            # mask with highest score
            mask_slice = slice(1, self.num_mask_tokens - 1)
            iou_pred = iou_pred[:, mask_slice]
            iou_pred, max_iou_idx = torch.max(iou_pred, dim=1)
            iou_pred = iou_pred.unsqueeze(1)
            masks_multi = masks[:, mask_slice, :, :]
            masks_sam = masks_multi[
                torch.arange(masks_multi.size(0)), max_iou_idx
            ].unsqueeze(1)
        else:
            # single mask output, default
            mask_slice = slice(0, 1)
            iou_pred = iou_pred[:, mask_slice]
            masks_sam = masks[:, mask_slice]

        masks_hq = masks[:, slice(self.num_mask_tokens - 1, self.num_mask_tokens)]
        if hq_token_only:
            masks = masks_hq
        else:
            masks = masks_sam + masks_hq

        # Prepare output
        return masks, iou_pred

    def predict_masks(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_prompt_embeddings: torch.Tensor,
        dense_prompt_embeddings: torch.Tensor,
        hq_features: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Predicts masks. See 'forward' for more details."""
        # Concatenate output tokens
        output_tokens = torch.cat(
            [self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight],
            dim=0,
        )
        output_tokens = output_tokens.unsqueeze(0).expand(
            sparse_prompt_embeddings.size(0), -1, -1
        )
        tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
        src = src + dense_prompt_embeddings
        pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
        b, c, h, w = src.shape

        # Run the transformer
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]

        # Upscale mask embeddings and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscaled_embedding_sam = self.output_upscaling(src)
        upscaled_embedding_hq = self.embedding_maskfeature(
            upscaled_embedding_sam
        ) + hq_features.repeat(b, 1, 1, 1)

        hyper_in_list: List[torch.Tensor] = []
        for i in range(self.num_mask_tokens):
            if i < self.num_mask_tokens - 1:
                hyper_in_list.append(
                    self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
                )
            else:
                hyper_in_list.append(self.hf_mlp(mask_tokens_out[:, i, :]))
        hyper_in = torch.stack(hyper_in_list, dim=1)
        b, c, h, w = upscaled_embedding_sam.shape

        masks_sam = (
            hyper_in[:, : self.num_mask_tokens - 1]
            @ upscaled_embedding_sam.view(b, c, h * w)
        ).view(b, -1, h, w)
        masks_sam_hq = (
            hyper_in[:, self.num_mask_tokens - 1 :]
            @ upscaled_embedding_hq.view(b, c, h * w)
        ).view(b, -1, h, w)
        masks = torch.cat([masks_sam, masks_sam_hq], dim=1)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred
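

# Usage sketch for the HQ decoder (illustrative, not part of the original file).
# `interm_embeddings` is assumed to be the list of intermediate feature maps
# returned by an HQ-SAM image encoder, in (B, H, W, C) layout:
#
#   decoder_hq = MaskDecoderHQ(
#       transformer_dim=256,
#       transformer=two_way_transformer,   # hypothetical transformer instance
#       vit_dim=1024,                      # 1024 for ViT-L; depends on the backbone
#   )
#   masks, iou_pred = decoder_hq(
#       image_embeddings=image_emb,
#       image_pe=image_pe,
#       sparse_prompt_embeddings=sparse_emb,
#       dense_prompt_embeddings=dense_emb,
#       multimask_output=False,
#       hq_token_only=True,                # return only the HQ mask
#       interm_embeddings=interm_embeddings,
#   )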


# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        if self.sigmoid_output:
            x = F.sigmoid(x)
        return x
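

if __name__ == "__main__":
    # Minimal smoke test for the MLP head (added for illustration; not part of
    # the upstream file). It only exercises the MLP defined above, with shapes
    # mirroring the IoU prediction head (transformer_dim=256, 4 mask tokens).
    # Because of the relative import at the top, run this as a module from
    # within its package rather than as a standalone script.
    mlp = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    scores = mlp(torch.randn(2, 256))
    print(scores.shape)  # expected: torch.Size([2, 4])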