sam_hq.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder


class SamHQ(nn.Module):
    mask_threshold: float = 0.0
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: ImageEncoderViT,
        prompt_encoder: PromptEncoder,
        mask_decoder: MaskDecoder,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
        )
        self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
        hq_token_only: bool = False,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly. A usage sketch
        showing the expected input format is included at the end of
        this file.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W).
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.
          hq_token_only (bool): Whether to predict masks from the HQ output
            token alone, rather than combining it with the standard SAM mask
            outputs.

        Returns:
          (list(dict)): A list over input images, where each element is
            a dictionary with the following keys.
              'masks': (torch.Tensor) Batched binary mask predictions,
                with shape BxCxHxW, where B is the number of input prompts,
                C is determined by multimask_output, and (H, W) is the
                original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
        """
        input_images = torch.stack(
            [self.preprocess(x["image"]) for x in batched_input], dim=0
        )
        image_embeddings, interm_embeddings = self.image_encoder(input_images)
        interm_embeddings = interm_embeddings[0]  # early layer

        outputs = []
        for image_record, curr_embedding, curr_interm in zip(
            batched_input, image_embeddings, interm_embeddings
        ):
            if "point_coords" in image_record:
                points = (image_record["point_coords"], image_record["point_labels"])
            else:
                points = None
            sparse_embeddings, dense_embeddings = self.prompt_encoder(
                points=points,
                boxes=image_record.get("boxes", None),
                masks=image_record.get("mask_inputs", None),
            )
            low_res_masks, iou_predictions = self.mask_decoder(
                image_embeddings=curr_embedding.unsqueeze(0),
                image_pe=self.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=multimask_output,
                hq_token_only=hq_token_only,
                interm_embeddings=curr_interm.unsqueeze(0).unsqueeze(0),
            )
            masks = self.postprocess_masks(
                low_res_masks,
                input_size=image_record["image"].shape[-2:],
                original_size=image_record["original_size"],
            )
            masks = masks > self.mask_threshold
            outputs.append(
                {
                    "masks": masks,
                    "iou_predictions": iou_predictions,
                    "low_res_logits": low_res_masks,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_encoder.img_size, self.image_encoder.img_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(
            masks, original_size, mode="bilinear", align_corners=False
        )
        return masks

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad
        h, w = x.shape[-2:]
        padh = self.image_encoder.img_size - h
        padw = self.image_encoder.img_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
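

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It shows
# the batched_input format documented in SamHQ.forward above. The
# `sam_model_registry` builder and the checkpoint filename are assumptions
# about how a SamHQ instance is typically constructed elsewhere in this
# package; any builder that returns a SamHQ is used the same way.
#
#     import torch
#     from segment_anything import sam_model_registry  # assumed builder
#
#     sam = sam_model_registry["vit_l"](checkpoint="sam_hq_vit_l.pth")
#
#     # Image already resized so its long side matches image_encoder.img_size
#     # (1024 by default); preprocess() then normalizes and pads it to a square.
#     image = torch.randint(0, 256, (3, 1024, 683)).float()
#     batched_input = [
#         {
#             "image": image,
#             "original_size": (1500, 1000),  # (H, W) before resizing
#             # Point prompts are given in the resized input frame.
#             "point_coords": torch.tensor([[[341.0, 512.0]]]),  # B x N x 2
#             "point_labels": torch.tensor([[1]]),  # B x N, 1 = foreground
#         }
#     ]
#
#     with torch.no_grad():
#         outputs = sam(batched_input, multimask_output=False, hq_token_only=True)
#
#     # Boolean masks at the original resolution, plus mask-quality scores and
#     # low-resolution logits that can be fed back in as 'mask_inputs'.
#     masks = outputs[0]["masks"]             # 1 x 1 x 1500 x 1000
#     scores = outputs[0]["iou_predictions"]  # 1 x 1
# ---------------------------------------------------------------------------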