transforms.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torchvision.transforms import Normalize, Resize, ToTensor
class SAM2Transforms(nn.Module):
    def __init__(
        self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
    ):
        """
        Transforms for SAM2.
        """
        super().__init__()
        self.resolution = resolution
        self.mask_threshold = mask_threshold
        self.max_hole_area = max_hole_area
        self.max_sprinkle_area = max_sprinkle_area
        # ImageNet normalization statistics.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.to_tensor = ToTensor()
        # Resize to a square model input and normalize; scripted once up front.
        self.transforms = torch.jit.script(
            nn.Sequential(
                Resize((self.resolution, self.resolution)),
                Normalize(self.mean, self.std),
            )
        )

    def __call__(self, x):
        x = self.to_tensor(x)
        return self.transforms(x)

    def forward_batch(self, img_list):
        # Transform each image independently, then stack into a single batch.
        img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
        img_batch = torch.stack(img_batch, dim=0)
        return img_batch
    def transform_coords(
        self, coords: torch.Tensor, normalize=False, orig_hw=None
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. The
        coordinates can be given either in absolute image coordinates or
        already normalized to [0, 1]. If they are in absolute image
        coordinates, set normalize=True and provide the original image size
        as orig_hw.

        Returns coordinates scaled to the model's input space, i.e. in the
        range [0, self.resolution], which is what the SAM2 model expects.
        """
        if normalize:
            assert orig_hw is not None
            h, w = orig_hw
            coords = coords.clone()
            coords[..., 0] = coords[..., 0] / w
            coords[..., 1] = coords[..., 1] / h

        coords = coords * self.resolution  # scale normalized coords to model resolution
        return coords
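    # Worked example (illustrative numbers, not from the file itself): a click
    # at (x=320, y=240) on a 640x480 image with normalize=True becomes
    # (0.5, 0.5) after dividing by (w, h), then (512.0, 512.0) after scaling
    # by resolution=1024, i.e. a coordinate in the model's input space rather
    # than the original image space.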
    def transform_boxes(
        self, boxes: torch.Tensor, normalize=False, orig_hw=None
    ) -> torch.Tensor:
        """
        Expects a tensor of shape Bx4 (two corner points per box). As with
        transform_coords, the coordinates can be absolute or normalized; if
        they are in absolute image coordinates, set normalize=True and provide
        the original image size as orig_hw.
        """
        boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
        return boxes
    def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
        """
        Perform post-processing on the output masks.
        """
        return masks
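
A minimal usage sketch (not part of transforms.py itself): the resolution of 1024 is assumed here as the model input size, and the image and point values are illustrative placeholders.

import numpy as np
import torch

# Hypothetical example values.
transforms = SAM2Transforms(resolution=1024, mask_threshold=0.0)

# Dummy HxWx3 uint8 image standing in for a real photo.
image = np.zeros((480, 640, 3), dtype=np.uint8)
img_tensor = transforms(image)              # (3, 1024, 1024), resized and normalized
batch = transforms.forward_batch([image])   # (1, 3, 1024, 1024)

# A point prompt in absolute pixel coordinates (x, y) of the original image.
points = torch.tensor([[[320.0, 240.0]]])
points_model = transforms.transform_coords(points, normalize=True, orig_hw=(480, 640))
# -> tensor([[[512., 512.]]]), i.e. coordinates in the 1024x1024 model input space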