# predictor.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import numpy as np
import torch

from .modeling import Sam
from .utils.transforms import ResizeLongestSide


class SamPredictor:
    def __init__(
        self,
        sam_model: Sam,
    ) -> None:
        """
        Uses SAM to calculate the image embedding for an image, and then
        allows repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam): The model to use for mask prediction.
        """
        super().__init__()
        self.model = sam_model
        self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
        self.reset_image()
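
    # Example usage (a sketch, not executed as part of the class): building a
    # predictor from a loaded Sam model. The registry key and checkpoint
    # filename below are assumptions; substitute whichever SAM variant and
    # checkpoint you actually have.
    #
    #   from segment_anything import sam_model_registry, SamPredictor
    #   sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
    #   predictor = SamPredictor(sam)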

    def set_image(
        self,
        image: np.ndarray,
        image_format: str = "RGB",
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray): The image for calculating masks. Expects an
            image in HWC uint8 format, with pixel values in [0, 255].
          image_format (str): The color format of the image, in ['RGB', 'BGR'].
        """
        assert image_format in [
            "RGB",
            "BGR",
        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
        if image_format != self.model.image_format:
            image = image[..., ::-1]

        # Transform the image to the form expected by the model
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

        self.set_torch_image(input_image_torch, image.shape[:2])
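
    # Example usage (a sketch): OpenCV loads images as BGR, so either convert
    # to RGB first or pass image_format="BGR". "truck.jpg" is a placeholder
    # path.
    #
    #   import cv2
    #   image = cv2.imread("truck.jpg")
    #   image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    #   predictor.set_image(image)
    #   # equivalently: predictor.set_image(cv2.imread("truck.jpg"), "BGR")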

    @torch.no_grad()
    def set_torch_image(
        self,
        transformed_image: torch.Tensor,
        original_image_size: Tuple[int, ...],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method. Expects the input
        image to be already transformed to the format expected by the model.

        Arguments:
          transformed_image (torch.Tensor): The input image, with shape
            1x3xHxW, which has been transformed with ResizeLongestSide.
          original_image_size (tuple(int, int)): The size of the image
            before transformation, in (H, W) format.
        """
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
        self.reset_image()

        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        self.features = self.model.image_encoder(input_image)
        self.is_image_set = True
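
    # Example usage (a sketch): feeding an image that is already a float
    # BxCxHxW tensor on the target device, assuming ResizeLongestSide provides
    # apply_image_torch (it does in the reference implementation); image_bchw,
    # h, and w are placeholders for your own tensor and original size.
    #
    #   resized = predictor.transform.apply_image_torch(image_bchw)
    #   predictor.set_torch_image(resized, original_image_size=(h, w))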

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
          point_coords (np.ndarray or None): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray or None): A low resolution mask input to the
            model, typically coming from a previous prediction iteration. Has
            form 1xHxW, where for SAM, H=W=256.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns unthresholded mask logits
            instead of a binary mask.

        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        # Transform input prompts
        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = self.transform.apply_coords(point_coords, self.original_size)
            coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
            labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box = self.transform.apply_boxes(box, self.original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
            box_torch = box_torch[None, :]
        if mask_input is not None:
            mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
            mask_input_torch = mask_input_torch[None, :, :, :]

        masks, iou_predictions, low_res_masks = self.predict_torch(
            coords_torch,
            labels_torch,
            box_torch,
            mask_input_torch,
            multimask_output,
            return_logits=return_logits,
        )

        masks = masks[0].detach().cpu().numpy()
        iou_predictions = iou_predictions[0].detach().cpu().numpy()
        low_res_masks = low_res_masks[0].detach().cpu().numpy()
        return masks, iou_predictions, low_res_masks
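
    # Example usage (a sketch): a single foreground click, keeping the highest
    # scoring of the three candidate masks and feeding its low resolution
    # logits back in for a refinement pass. The click coordinates are
    # placeholders.
    #
    #   masks, scores, logits = predictor.predict(
    #       point_coords=np.array([[500, 375]]),
    #       point_labels=np.array([1]),
    #       multimask_output=True,
    #   )
    #   best = np.argmax(scores)
    #   masks, scores, _ = predictor.predict(
    #       point_coords=np.array([[500, 375]]),
    #       point_labels=np.array([1]),
    #       mask_input=logits[best : best + 1, :, :],
    #       multimask_output=False,
    #   )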

    @torch.no_grad()
    def predict_torch(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using ResizeLongestSide.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to
            the model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor or None): A low resolution mask input to the
            model, typically coming from a previous prediction iteration. Has
            form Bx1xHxW, where for SAM, H=W=256. Masks returned by a previous
            iteration of the predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns unthresholded mask logits
            instead of a binary mask.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
        )

        # Upscale the masks to the original image resolution
        masks = self.model.postprocess_masks(
            low_res_masks, self.input_size, self.original_size
        )

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks
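
    # Example usage (a sketch): batched box prompts. Boxes must be transformed
    # to the model's input frame first; apply_boxes_torch exists on
    # ResizeLongestSide in the reference implementation. The box coordinates
    # are placeholders, and image.shape[:2] is the (H, W) of the original
    # image passed to set_image.
    #
    #   input_boxes = torch.tensor(
    #       [[75, 275, 1725, 850], [425, 600, 700, 875]], device=predictor.device
    #   )
    #   transformed = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
    #   masks, scores, logits = predictor.predict_torch(
    #       point_coords=None,
    #       point_labels=None,
    #       boxes=transformed,
    #       multimask_output=False,
    #   )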

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H, W) are
        the embedding spatial dimensions of SAM (typically C=256, H=W=64).
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert self.features is not None, "Features must exist if an image has been set."
        return self.features
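
    # Example usage (a sketch): the embedding can be cached and reused, e.g. to
    # serve many prompts for the same image without re-running the image
    # encoder. The shape comment assumes the default ViT image encoder, and
    # "embedding.pt" is a placeholder path.
    #
    #   emb = predictor.get_image_embedding()  # typically (1, 256, 64, 64)
    #   torch.save(emb.cpu(), "embedding.pt")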

    @property
    def device(self) -> torch.device:
        return self.model.device

    def reset_image(self) -> None:
        """Resets the currently set image."""
        self.is_image_set = False
        self.features = None
        self.original_size = None
        self.input_size = None