sam.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class Sam(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        # persistent=False: the normalization constants are not saved in the
        # model's state_dict.
        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
        )
        self.register_buffer(
            "pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False
        )

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    @torch.no_grad()
    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is
            a dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
        """
        input_images = torch.stack(
            [self.preprocess(x["image"]) for x in batched_input], dim=0
        )
        image_embeddings = self.image_encoder(input_images)

        outputs = []
        for image_record, curr_embedding in zip(batched_input, image_embeddings):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs
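
    # ------------------------------------------------------------------
    # Usage sketch (illustrative, not part of the original file). Assumes
    # a model built via the package's sam_model_registry and the
    # ResizeLongestSide transform from segment_anything.utils.transforms;
    # the checkpoint path and the example point are placeholders.
    #
    #   import numpy as np
    #   from segment_anything import sam_model_registry
    #   from segment_anything.utils.transforms import ResizeLongestSide
    #
    #   sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b.pth")
    #   transform = ResizeLongestSide(sam.image_encoder.img_size)
    #
    #   image = ...  # HxWx3 uint8 RGB array
    #   input_image = torch.as_tensor(
    #       transform.apply_image(image), device=sam.device
    #   ).permute(2, 0, 1).contiguous()
    #   coords = transform.apply_coords(np.array([[[500.0, 375.0]]]), image.shape[:2])
    #   batched_input = [{
    #       "image": input_image,
    #       "original_size": image.shape[:2],
    #       "point_coords": torch.as_tensor(coords, dtype=torch.float, device=sam.device),
    #       "point_labels": torch.tensor([[1]], device=sam.device),  # 1 = foreground
    #   }]
    #   outputs = sam(batched_input, multimask_output=True)
    #   masks = outputs[0]["masks"]  # 1x3xHxW boolean masks at original size
    # ------------------------------------------------------------------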

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(
            masks, original_size, mode="bilinear", align_corners=False
        )
        return masks
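
    # ------------------------------------------------------------------
    # Shape walk-through (illustrative, assuming the default img_size of
    # 1024): for a 1500x2250 image resized so its long side is 1024, the
    # decoder's Bx1x256x256 logits are upsampled to Bx1x1024x1024, cropped
    # to the unpadded input region Bx1x683x1024, and finally resized to
    # the original Bx1x1500x2250.
    # ------------------------------------------------------------------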

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
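
    # ------------------------------------------------------------------
    # Padding sketch (illustrative, assuming the default img_size of
    # 1024): a 3x683x1024 input is normalized per channel, then
    # zero-padded on the right and bottom to 3x1024x1024:
    #
    #   x = torch.randn(3, 683, 1024)
    #   x = F.pad(x, (0, 1024 - 1024, 0, 1024 - 683))  # (left, right, top, bottom)
    #   assert x.shape == (3, 1024, 1024)
    # ------------------------------------------------------------------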