# predictor_hq.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import numpy as np
import torch

from .modeling import Sam
from .utils.transforms import ResizeLongestSide


class SamHQPredictor:
    def __init__(
        self,
        sam_model: Sam,
    ) -> None:
        """
        Uses SAM to calculate the image embedding for an image, and then
        allows repeated, efficient mask prediction given prompts.

        Arguments:
          sam_model (Sam): The model to use for mask prediction.
        """
        super().__init__()
        self.model = sam_model
        self.transform = ResizeLongestSide(sam_model.image_encoder.img_size)
        self.reset_image()

    def set_image(
        self,
        image: np.ndarray,
        image_format: str = "RGB",
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method.

        Arguments:
          image (np.ndarray): The image for calculating masks. Expects an
            image in HWC uint8 format, with pixel values in [0, 255].
          image_format (str): The color format of the image, in ['RGB', 'BGR'].
        """
        assert image_format in [
            "RGB",
            "BGR",
        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
        if image_format != self.model.image_format:
            image = image[..., ::-1]

        # Transform the image to the form expected by the model
        input_image = self.transform.apply_image(image)
        input_image_torch = torch.as_tensor(input_image, device=self.device)
        input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
            None, :, :, :
        ]

        self.set_torch_image(input_image_torch, image.shape[:2])
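
    # Usage sketch (illustrative; `img` is assumed to be an HWC uint8 RGB array):
    #   predictor.set_image(img)
    # The image embedding is computed once here and cached, so repeated
    # predict() calls reuse it without re-running the image encoder.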

    @torch.no_grad()
    def set_torch_image(
        self,
        transformed_image: torch.Tensor,
        original_image_size: Tuple[int, ...],
    ) -> None:
        """
        Calculates the image embeddings for the provided image, allowing
        masks to be predicted with the 'predict' method. Expects the input
        image to be already transformed to the format expected by the model.

        Arguments:
          transformed_image (torch.Tensor): The input image, with shape
            1x3xHxW, which has been transformed with ResizeLongestSide.
          original_image_size (tuple(int, int)): The size of the image
            before transformation, in (H, W) format.
        """
        assert (
            len(transformed_image.shape) == 4
            and transformed_image.shape[1] == 3
            and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size
        ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}."
        self.reset_image()

        self.original_size = original_image_size
        self.input_size = tuple(transformed_image.shape[-2:])
        input_image = self.model.preprocess(transformed_image)
        # The HQ image encoder returns both the final embedding and the
        # intermediate encoder features consumed by the HQ mask decoder.
        self.features, self.interm_features = self.model.image_encoder(input_image)
        self.is_image_set = True
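
    # set_torch_image is the lower-level entry point for pipelines that
    # transform images themselves. Illustrative sketch, assuming `image` is an
    # HWC uint8 array:
    #   t = predictor.transform.apply_image(image)
    #   t = torch.as_tensor(t, device=predictor.device).permute(2, 0, 1)[None]
    #   predictor.set_torch_image(t, image.shape[:2])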

    def predict(
        self,
        point_coords: Optional[np.ndarray] = None,
        point_labels: Optional[np.ndarray] = None,
        box: Optional[np.ndarray] = None,
        mask_input: Optional[np.ndarray] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        hq_token_only: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict masks for the given input prompts, using the currently set image.

        Arguments:
          point_coords (np.ndarray or None): An Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (np.ndarray or None): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          box (np.ndarray or None): A length 4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (np.ndarray): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where
            for SAM, H=W=256.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.
          hq_token_only (bool): If true, the returned masks come from the HQ
            output token alone; if false, the HQ output is used to correct the
            original SAM masks.

        Returns:
          (np.ndarray): The output masks in CxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (np.ndarray): An array of length C containing the model's
            predictions for the quality of each mask.
          (np.ndarray): An array of shape CxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        # Transform input prompts
        coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
        if point_coords is not None:
            assert (
                point_labels is not None
            ), "point_labels must be supplied if point_coords is supplied."
            point_coords = self.transform.apply_coords(point_coords, self.original_size)
            coords_torch = torch.as_tensor(
                point_coords, dtype=torch.float, device=self.device
            )
            labels_torch = torch.as_tensor(
                point_labels, dtype=torch.int, device=self.device
            )
            coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
        if box is not None:
            box = self.transform.apply_boxes(box, self.original_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
            box_torch = box_torch[None, :]
        if mask_input is not None:
            mask_input_torch = torch.as_tensor(
                mask_input, dtype=torch.float, device=self.device
            )
            mask_input_torch = mask_input_torch[None, :, :, :]

        masks, iou_predictions, low_res_masks = self.predict_torch(
            coords_torch,
            labels_torch,
            box_torch,
            mask_input_torch,
            multimask_output,
            return_logits=return_logits,
            hq_token_only=hq_token_only,
        )

        masks_np = masks[0].detach().cpu().numpy()
        iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
        low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
        return masks_np, iou_predictions_np, low_res_masks_np
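
    # Example (illustrative): a single foreground click at pixel (x, y):
    #   masks, scores, low_res = predictor.predict(
    #       point_coords=np.array([[x, y]]),
    #       point_labels=np.array([1]),
    #       multimask_output=True,
    #   )
    #   best = masks[np.argmax(scores)]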

    @torch.no_grad()
    def predict_torch(
        self,
        point_coords: Optional[torch.Tensor],
        point_labels: Optional[torch.Tensor],
        boxes: Optional[torch.Tensor] = None,
        mask_input: Optional[torch.Tensor] = None,
        multimask_output: bool = True,
        return_logits: bool = False,
        hq_token_only: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict masks for the given input prompts, using the currently set image.
        Input prompts are batched torch tensors and are expected to already be
        transformed to the input frame using ResizeLongestSide.

        Arguments:
          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels.
          point_labels (torch.Tensor or None): A BxN array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a
            background point.
          boxes (torch.Tensor or None): A Bx4 array giving a box prompt to the
            model, in XYXY format.
          mask_input (torch.Tensor): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form Bx1xHxW, where
            for SAM, H=W=256. Masks returned by a previous iteration of the
            predict method do not need further transformation.
          multimask_output (bool): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
          return_logits (bool): If true, returns un-thresholded mask logits
            instead of a binary mask.
          hq_token_only (bool): If true, the returned masks come from the HQ
            output token alone; if false, the HQ output is used to correct the
            original SAM masks.

        Returns:
          (torch.Tensor): The output masks in BxCxHxW format, where C is the
            number of masks, and (H, W) is the original image size.
          (torch.Tensor): An array of shape BxC containing the model's
            predictions for the quality of each mask.
          (torch.Tensor): An array of shape BxCxHxW, where C is the number
            of masks and H=W=256. These low resolution logits can be passed to
            a subsequent iteration as mask input.
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) before mask prediction."
            )

        if point_coords is not None:
            points = (point_coords, point_labels)
        else:
            points = None

        # Embed prompts
        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=mask_input,
        )

        # Predict masks
        low_res_masks, iou_predictions = self.model.mask_decoder(
            image_embeddings=self.features,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            hq_token_only=hq_token_only,
            interm_embeddings=self.interm_features,
        )

        # Upscale the masks to the original image resolution
        masks = self.model.postprocess_masks(
            low_res_masks, self.input_size, self.original_size
        )

        if not return_logits:
            masks = masks > self.model.mask_threshold

        return masks, iou_predictions, low_res_masks
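
    # Example (illustrative): batched box prompts on the torch path, assuming
    # `raw_boxes` is a Bx4 tensor of XYXY boxes in original image coordinates:
    #   boxes = predictor.transform.apply_boxes_torch(raw_boxes, image.shape[:2])
    #   masks, scores, low_res = predictor.predict_torch(
    #       point_coords=None,
    #       point_labels=None,
    #       boxes=boxes,
    #       multimask_output=False,
    #   )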

    def get_image_embedding(self) -> torch.Tensor:
        """
        Returns the image embeddings for the currently set image, with
        shape 1xCxHxW, where C is the embedding dimension and (H, W) are
        the embedding spatial dimensions of SAM (typically C=256, H=W=64).
        """
        if not self.is_image_set:
            raise RuntimeError(
                "An image must be set with .set_image(...) to generate an embedding."
            )
        assert (
            self.features is not None
        ), "Features must exist if an image has been set."
        return self.features

    @property
    def device(self) -> torch.device:
        return self.model.device

    def reset_image(self) -> None:
        """Resets the currently set image."""
        self.is_image_set = False
        self.features = None
        # Also clear the intermediate HQ features so a reset predictor cannot
        # reuse stale encoder state.
        self.interm_features = None
        self.orig_h = None
        self.orig_w = None
        self.input_h = None
        self.input_w = None
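

# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (illustrative, not part of the original module).
# It assumes this file lives in SAM-HQ's `segment_anything` package, that
# `sam_model_registry` is importable from that package, and that the checkpoint
# and image paths below exist; run it from the package root, e.g. with
# `python -m segment_anything.predictor_hq`.
if __name__ == "__main__":
    import cv2  # assumed available for image loading

    from segment_anything import sam_model_registry  # assumed package layout

    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry["vit_l"](checkpoint="sam_hq_vit_l.pth")  # hypothetical path
    sam.to(device)

    predictor = SamHQPredictor(sam)

    # Load an image as HWC uint8 RGB, the format set_image expects.
    image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    # A single foreground click; label 1 marks a foreground point.
    masks, scores, low_res_logits = predictor.predict(
        point_coords=np.array([[500, 375]]),
        point_labels=np.array([1]),
        multimask_output=True,
        hq_token_only=False,
    )
    print(masks.shape, scores)  # (3, H, W) boolean masks and their quality scores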