# sam2_utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
  5. import copy
  6. from typing import Tuple
  7. import numpy as np
  8. import torch
  9. import torch.nn as nn
  10. import torch.nn.functional as F
  11. from ..utils.misc import mask_to_box
  12. def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
  13. """
  14. Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
  15. that are temporally closest to the current frame at `frame_idx`. Here, we take
  16. - a) the closest conditioning frame before `frame_idx` (if any);
  17. - b) the closest conditioning frame after `frame_idx` (if any);
  18. - c) any other temporally closest conditioning frames until reaching a total
  19. of `max_cond_frame_num` conditioning frames.
  20. Outputs:
  21. - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
  22. - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
  23. """
  24. if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
  25. selected_outputs = cond_frame_outputs
  26. unselected_outputs = {}
  27. else:
  28. assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
  29. selected_outputs = {}
  30. # the closest conditioning frame before `frame_idx` (if any)
  31. idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
  32. if idx_before is not None:
  33. selected_outputs[idx_before] = cond_frame_outputs[idx_before]
  34. # the closest conditioning frame after `frame_idx` (if any)
  35. idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
  36. if idx_after is not None:
  37. selected_outputs[idx_after] = cond_frame_outputs[idx_after]
  38. # add other temporally closest conditioning frames until reaching a total
  39. # of `max_cond_frame_num` conditioning frames.
  40. num_remain = max_cond_frame_num - len(selected_outputs)
  41. inds_remain = sorted(
  42. (t for t in cond_frame_outputs if t not in selected_outputs),
  43. key=lambda x: abs(x - frame_idx),
  44. )[:num_remain]
  45. selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
  46. unselected_outputs = {
  47. t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
  48. }
  49. return selected_outputs, unselected_outputs
  50. def get_1d_sine_pe(pos_inds, dim, temperature=10000):
  51. """
  52. Get 1D sine positional embedding as in the original Transformer paper.
  53. """
  54. pe_dim = dim // 2
  55. dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
  56. dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
  57. pos_embed = pos_inds.unsqueeze(-1) / dim_t
  58. pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
  59. return pos_embed
  60. def get_activation_fn(activation):
  61. """Return an activation function given a string"""
  62. if activation == "relu":
  63. return F.relu
  64. if activation == "gelu":
  65. return F.gelu
  66. if activation == "glu":
  67. return F.glu
  68. raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
  69. def get_clones(module, N):
  70. return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
  71. class DropPath(nn.Module):
  72. # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
  73. def __init__(self, drop_prob=0.0, scale_by_keep=True):
  74. super(DropPath, self).__init__()
  75. self.drop_prob = drop_prob
  76. self.scale_by_keep = scale_by_keep
  77. def forward(self, x):
  78. if self.drop_prob == 0.0 or not self.training:
  79. return x
  80. keep_prob = 1 - self.drop_prob
  81. shape = (x.shape[0],) + (1,) * (x.ndim - 1)
  82. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  83. if keep_prob > 0.0 and self.scale_by_keep:
  84. random_tensor.div_(keep_prob)
  85. return x * random_tensor
  86. # Lightly adapted from
  87. # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
  88. class MLP(nn.Module):
  89. def __init__(
  90. self,
  91. input_dim: int,
  92. hidden_dim: int,
  93. output_dim: int,
  94. num_layers: int,
  95. activation: nn.Module = nn.ReLU,
  96. sigmoid_output: bool = False,
  97. ) -> None:
  98. super().__init__()
  99. self.num_layers = num_layers
  100. h = [hidden_dim] * (num_layers - 1)
  101. self.layers = nn.ModuleList(
  102. nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
  103. )
  104. self.sigmoid_output = sigmoid_output
  105. self.act = activation()
  106. def forward(self, x):
  107. for i, layer in enumerate(self.layers):
  108. x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
  109. if self.sigmoid_output:
  110. x = F.sigmoid(x)
  111. return x
  112. # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
  113. # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
  114. class LayerNorm2d(nn.Module):
  115. def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
  116. super().__init__()
  117. self.weight = nn.Parameter(torch.ones(num_channels))
  118. self.bias = nn.Parameter(torch.zeros(num_channels))
  119. self.eps = eps
  120. def forward(self, x: torch.Tensor) -> torch.Tensor:
  121. u = x.mean(1, keepdim=True)
  122. s = (x - u).pow(2).mean(1, keepdim=True)
  123. x = (x - u) / torch.sqrt(s + self.eps)
  124. x = self.weight[:, None, None] * x + self.bias[:, None, None]
  125. return x
  126. def sample_box_points(
  127. masks: torch.Tensor,
  128. noise: float = 0.1, # SAM default
  129. noise_bound: int = 20, # SAM default
  130. top_left_label: int = 2,
  131. bottom_right_label: int = 3,
  132. ) -> Tuple[np.array, np.array]:
  133. """
  134. Sample a noised version of the top left and bottom right corners of a given `bbox`
  135. Inputs:
  136. - masks: [B, 1, H,W] boxes, dtype=torch.Tensor
  137. - noise: noise as a fraction of box width and height, dtype=float
  138. - noise_bound: maximum amount of noise (in pure pixesl), dtype=int
  139. Returns:
  140. - box_coords: [B, num_pt, 2], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.float
  141. - box_labels: [B, num_pt], label 2 is reserverd for top left and 3 for bottom right corners, dtype=torch.int32
  142. """
  143. device = masks.device
  144. box_coords = mask_to_box(masks)
  145. B, _, H, W = masks.shape
  146. box_labels = torch.tensor(
  147. [top_left_label, bottom_right_label], dtype=torch.int, device=device
  148. ).repeat(B)
  149. if noise > 0.0:
  150. if not isinstance(noise_bound, torch.Tensor):
  151. noise_bound = torch.tensor(noise_bound, device=device)
  152. bbox_w = box_coords[..., 2] - box_coords[..., 0]
  153. bbox_h = box_coords[..., 3] - box_coords[..., 1]
  154. max_dx = torch.min(bbox_w * noise, noise_bound)
  155. max_dy = torch.min(bbox_h * noise, noise_bound)
  156. box_noise = 2 * torch.rand(B, 1, 4, device=device) - 1
  157. box_noise = box_noise * torch.stack((max_dx, max_dy, max_dx, max_dy), dim=-1)
  158. box_coords = box_coords + box_noise
  159. img_bounds = (
  160. torch.tensor([W, H, W, H], device=device) - 1
  161. ) # uncentered pixel coords
  162. box_coords.clamp_(torch.zeros_like(img_bounds), img_bounds) # In place clamping
  163. box_coords = box_coords.reshape(-1, 2, 2) # always 2 points
  164. box_labels = box_labels.reshape(-1, 2)
  165. return box_coords, box_labels
  166. def sample_random_points_from_errors(gt_masks, pred_masks, num_pt=1):
  167. """
  168. Sample `num_pt` random points (along with their labels) independently from the error regions.
  169. Inputs:
  170. - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
  171. - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
  172. - num_pt: int, number of points to sample independently for each of the B error maps
  173. Outputs:
  174. - points: [B, num_pt, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
  175. - labels: [B, num_pt], dtype=torch.int32, where 1 means positive clicks and 0 means
  176. negative clicks
  177. """
  178. if pred_masks is None: # if pred_masks is not provided, treat it as empty
  179. pred_masks = torch.zeros_like(gt_masks)
  180. assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
  181. assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
  182. assert num_pt >= 0
  183. B, _, H_im, W_im = gt_masks.shape
  184. device = gt_masks.device
  185. # false positive region, a new point sampled in this region should have
  186. # negative label to correct the FP error
  187. fp_masks = ~gt_masks & pred_masks
  188. # false negative region, a new point sampled in this region should have
  189. # positive label to correct the FN error
  190. fn_masks = gt_masks & ~pred_masks
  191. # whether the prediction completely match the ground-truth on each mask
  192. all_correct = torch.all((gt_masks == pred_masks).flatten(2), dim=2)
  193. all_correct = all_correct[..., None, None]
  194. # channel 0 is FP map, while channel 1 is FN map
  195. pts_noise = torch.rand(B, num_pt, H_im, W_im, 2, device=device)
  196. # sample a negative new click from FP region or a positive new click
  197. # from FN region, depend on where the maximum falls,
  198. # and in case the predictions are all correct (no FP or FN), we just
  199. # sample a negative click from the background region
  200. pts_noise[..., 0] *= fp_masks | (all_correct & ~gt_masks)
  201. pts_noise[..., 1] *= fn_masks
  202. pts_idx = pts_noise.flatten(2).argmax(dim=2)
  203. labels = (pts_idx % 2).to(torch.int32)
  204. pts_idx = pts_idx // 2
  205. pts_x = pts_idx % W_im
  206. pts_y = pts_idx // W_im
  207. points = torch.stack([pts_x, pts_y], dim=2).to(torch.float)
  208. return points, labels
  209. def sample_one_point_from_error_center(gt_masks, pred_masks, padding=True):
  210. """
  211. Sample 1 random point (along with its label) from the center of each error region,
  212. that is, the point with the largest distance to the boundary of each error region.
  213. This is the RITM sampling method from https://github.com/saic-vul/ritm_interactive_segmentation/blob/master/isegm/inference/clicker.py
  214. Inputs:
  215. - gt_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool
  216. - pred_masks: [B, 1, H_im, W_im] masks, dtype=torch.bool or None
  217. - padding: if True, pad with boundary of 1 px for distance transform
  218. Outputs:
  219. - points: [B, 1, 2], dtype=torch.float, contains (x, y) coordinates of each sampled point
  220. - labels: [B, 1], dtype=torch.int32, where 1 means positive clicks and 0 means negative clicks
  221. """
  222. import cv2
  223. if pred_masks is None:
  224. pred_masks = torch.zeros_like(gt_masks)
  225. assert gt_masks.dtype == torch.bool and gt_masks.size(1) == 1
  226. assert pred_masks.dtype == torch.bool and pred_masks.shape == gt_masks.shape
  227. B, _, _, W_im = gt_masks.shape
  228. device = gt_masks.device
  229. # false positive region, a new point sampled in this region should have
  230. # negative label to correct the FP error
  231. fp_masks = ~gt_masks & pred_masks
  232. # false negative region, a new point sampled in this region should have
  233. # positive label to correct the FN error
  234. fn_masks = gt_masks & ~pred_masks
  235. fp_masks = fp_masks.cpu().numpy()
  236. fn_masks = fn_masks.cpu().numpy()
  237. points = torch.zeros(B, 1, 2, dtype=torch.float)
  238. labels = torch.ones(B, 1, dtype=torch.int32)
  239. for b in range(B):
  240. fn_mask = fn_masks[b, 0]
  241. fp_mask = fp_masks[b, 0]
  242. if padding:
  243. fn_mask = np.pad(fn_mask, ((1, 1), (1, 1)), "constant")
  244. fp_mask = np.pad(fp_mask, ((1, 1), (1, 1)), "constant")
  245. # compute the distance of each point in FN/FP region to its boundary
  246. fn_mask_dt = cv2.distanceTransform(fn_mask.astype(np.uint8), cv2.DIST_L2, 0)
  247. fp_mask_dt = cv2.distanceTransform(fp_mask.astype(np.uint8), cv2.DIST_L2, 0)
  248. if padding:
  249. fn_mask_dt = fn_mask_dt[1:-1, 1:-1]
  250. fp_mask_dt = fp_mask_dt[1:-1, 1:-1]
  251. # take the point in FN/FP region with the largest distance to its boundary
  252. fn_mask_dt_flat = fn_mask_dt.reshape(-1)
  253. fp_mask_dt_flat = fp_mask_dt.reshape(-1)
  254. fn_argmax = np.argmax(fn_mask_dt_flat)
  255. fp_argmax = np.argmax(fp_mask_dt_flat)
  256. is_positive = fn_mask_dt_flat[fn_argmax] > fp_mask_dt_flat[fp_argmax]
  257. pt_idx = fn_argmax if is_positive else fp_argmax
  258. points[b, 0, 0] = pt_idx % W_im # x
  259. points[b, 0, 1] = pt_idx // W_im # y
  260. labels[b, 0] = int(is_positive)
  261. points = points.to(device)
  262. labels = labels.to(device)
  263. return points, labels
  264. def get_next_point(gt_masks, pred_masks, method):
  265. if method == "uniform":
  266. return sample_random_points_from_errors(gt_masks, pred_masks)
  267. elif method == "center":
  268. return sample_one_point_from_error_center(gt_masks, pred_masks)
  269. else:
  270. raise ValueError(f"unknown sampling method {method}")