prompt_encoder.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. from typing import Any, Optional, Tuple, Type
  6. import numpy as np
  7. import torch
  8. from torch import nn
  9. from .common import LayerNorm2d
  10. class PromptEncoder(nn.Module):
  11. def __init__(
  12. self,
  13. embed_dim: int,
  14. image_embedding_size: Tuple[int, int],
  15. input_image_size: Tuple[int, int],
  16. mask_in_chans: int,
  17. activation: Type[nn.Module] = nn.GELU,
  18. ) -> None:
  19. """
  20. Encodes prompts for input to SAM's mask decoder.
  21. Arguments:
  22. embed_dim (int): The prompts' embedding dimension
  23. image_embedding_size (tuple(int, int)): The spatial size of the
  24. image embedding, as (H, W).
  25. input_image_size (int): The padded size of the image as input
  26. to the image encoder, as (H, W).
  27. mask_in_chans (int): The number of hidden channels used for
  28. encoding input masks.
  29. activation (nn.Module): The activation to use when encoding
  30. input masks.
  31. """
  32. super().__init__()
  33. self.embed_dim = embed_dim
  34. self.input_image_size = input_image_size
  35. self.image_embedding_size = image_embedding_size
  36. self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
  37. self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
  38. point_embeddings = [
  39. nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
  40. ]
  41. self.point_embeddings = nn.ModuleList(point_embeddings)
  42. self.not_a_point_embed = nn.Embedding(1, embed_dim)
  43. self.mask_input_size = (
  44. 4 * image_embedding_size[0],
  45. 4 * image_embedding_size[1],
  46. )
  47. self.mask_downscaling = nn.Sequential(
  48. nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
  49. LayerNorm2d(mask_in_chans // 4),
  50. activation(),
  51. nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
  52. LayerNorm2d(mask_in_chans),
  53. activation(),
  54. nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
  55. )
  56. self.no_mask_embed = nn.Embedding(1, embed_dim)
  57. def get_dense_pe(self) -> torch.Tensor:
  58. """
  59. Returns the positional encoding used to encode point prompts,
  60. applied to a dense set of points the shape of the image encoding.
  61. Returns:
  62. torch.Tensor: Positional encoding with shape
  63. 1x(embed_dim)x(embedding_h)x(embedding_w)
  64. """
  65. return self.pe_layer(self.image_embedding_size).unsqueeze(0)
  66. def _embed_points(
  67. self,
  68. points: torch.Tensor,
  69. labels: torch.Tensor,
  70. pad: bool,
  71. ) -> torch.Tensor:
  72. """Embeds point prompts."""
  73. points = points + 0.5 # Shift to center of pixel
  74. if pad:
  75. padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
  76. padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
  77. points = torch.cat([points, padding_point], dim=1)
  78. labels = torch.cat([labels, padding_label], dim=1)
  79. point_embedding = self.pe_layer.forward_with_coords(
  80. points, self.input_image_size
  81. )
  82. point_embedding[labels == -1] = 0.0
  83. point_embedding[labels == -1] += self.not_a_point_embed.weight
  84. point_embedding[labels == 0] += self.point_embeddings[0].weight
  85. point_embedding[labels == 1] += self.point_embeddings[1].weight
  86. return point_embedding
  87. def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
  88. """Embeds box prompts."""
  89. boxes = boxes + 0.5 # Shift to center of pixel
  90. coords = boxes.reshape(-1, 2, 2)
  91. corner_embedding = self.pe_layer.forward_with_coords(
  92. coords, self.input_image_size
  93. )
  94. corner_embedding[:, 0, :] += self.point_embeddings[2].weight
  95. corner_embedding[:, 1, :] += self.point_embeddings[3].weight
  96. return corner_embedding
  97. def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
  98. """Embeds mask inputs."""
  99. mask_embedding = self.mask_downscaling(masks)
  100. return mask_embedding
  101. def _get_batch_size(
  102. self,
  103. points: Optional[Tuple[torch.Tensor, torch.Tensor]],
  104. boxes: Optional[torch.Tensor],
  105. masks: Optional[torch.Tensor],
  106. ) -> int:
  107. """
  108. Gets the batch size of the output given the batch size of the input prompts.
  109. """
  110. if points is not None:
  111. return points[0].shape[0]
  112. elif boxes is not None:
  113. return boxes.shape[0]
  114. elif masks is not None:
  115. return masks.shape[0]
  116. else:
  117. return 1
  118. def _get_device(self) -> torch.device:
  119. return self.point_embeddings[0].weight.device
  120. def forward(
  121. self,
  122. points: Optional[Tuple[torch.Tensor, torch.Tensor]],
  123. boxes: Optional[torch.Tensor],
  124. masks: Optional[torch.Tensor],
  125. ) -> Tuple[torch.Tensor, torch.Tensor]:
  126. """
  127. Embeds different types of prompts, returning both sparse and dense
  128. embeddings.
  129. Arguments:
  130. points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
  131. and labels to embed.
  132. boxes (torch.Tensor or none): boxes to embed
  133. masks (torch.Tensor or none): masks to embed
  134. Returns:
  135. torch.Tensor: sparse embeddings for the points and boxes, with shape
  136. BxNx(embed_dim), where N is determined by the number of input points
  137. and boxes.
  138. torch.Tensor: dense embeddings for the masks, in the shape
  139. Bx(embed_dim)x(embed_H)x(embed_W)
  140. """
  141. bs = self._get_batch_size(points, boxes, masks)
  142. sparse_embeddings = torch.empty(
  143. (bs, 0, self.embed_dim), device=self._get_device()
  144. )
  145. if points is not None:
  146. coords, labels = points
  147. point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
  148. sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
  149. if boxes is not None:
  150. box_embeddings = self._embed_boxes(boxes)
  151. sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
  152. if masks is not None:
  153. dense_embeddings = self._embed_masks(masks)
  154. else:
  155. dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
  156. bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
  157. )
  158. return sparse_embeddings, dense_embeddings
  159. class PositionEmbeddingRandom(nn.Module):
  160. """
  161. Positional encoding using random spatial frequencies.
  162. """
  163. def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
  164. super().__init__()
  165. if scale is None or scale <= 0.0:
  166. scale = 1.0
  167. self.register_buffer(
  168. "positional_encoding_gaussian_matrix",
  169. scale * torch.randn((2, num_pos_feats)),
  170. )
  171. def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
  172. """Positionally encode points that are normalized to [0,1]."""
  173. # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
  174. coords = 2 * coords - 1
  175. coords = coords @ self.positional_encoding_gaussian_matrix
  176. coords = 2 * np.pi * coords
  177. # outputs d_1 x ... x d_n x C shape
  178. return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
  179. def forward(self, size: Tuple[int, int]) -> torch.Tensor:
  180. """Generate positional encoding for a grid of the specified size."""
  181. h, w = size
  182. device: Any = self.positional_encoding_gaussian_matrix.device
  183. grid = torch.ones((h, w), device=device, dtype=torch.float32)
  184. y_embed = grid.cumsum(dim=0) - 0.5
  185. x_embed = grid.cumsum(dim=1) - 0.5
  186. y_embed = y_embed / h
  187. x_embed = x_embed / w
  188. pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
  189. return pe.permute(2, 0, 1) # C x H x W
  190. def forward_with_coords(
  191. self, coords_input: torch.Tensor, image_size: Tuple[int, int]
  192. ) -> torch.Tensor:
  193. """Positionally encode points that are not normalized to [0,1]."""
  194. coords = coords_input.clone()
  195. coords[:, :, 0] = coords[:, :, 0] / image_size[1]
  196. coords[:, :, 1] = coords[:, :, 1] / image_size[0]
  197. return self._pe_encoding(coords.to(torch.float)) # B x N x C