transforms.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
  5. from copy import deepcopy
  6. from typing import Tuple
  7. import numpy as np
  8. import torch
  9. from torch.nn import functional as F
  10. from torchvision.transforms.functional import resize # type: ignore
  11. from torchvision.transforms.functional import to_pil_image
  12. class ResizeLongestSide:
  13. """
  14. Resizes images to longest side 'target_length', as well as provides
  15. methods for resizing coordinates and boxes. Provides methods for
  16. transforming both numpy array and batched torch tensors.
  17. """
  18. def __init__(self, target_length: int) -> None:
  19. self.target_length = target_length
  20. def apply_image(self, image: np.ndarray) -> np.ndarray:
  21. """
  22. Expects a numpy array with shape HxWxC in uint8 format.
  23. """
  24. target_size = self.get_preprocess_shape(
  25. image.shape[0], image.shape[1], self.target_length
  26. )
  27. return np.array(resize(to_pil_image(image), target_size))
  28. def apply_coords(
  29. self, coords: np.ndarray, original_size: Tuple[int, ...]
  30. ) -> np.ndarray:
  31. """
  32. Expects a numpy array of length 2 in the final dimension. Requires the
  33. original image size in (H, W) format.
  34. """
  35. old_h, old_w = original_size
  36. new_h, new_w = self.get_preprocess_shape(
  37. original_size[0], original_size[1], self.target_length
  38. )
  39. coords = deepcopy(coords).astype(float)
  40. coords[..., 0] = coords[..., 0] * (new_w / old_w)
  41. coords[..., 1] = coords[..., 1] * (new_h / old_h)
  42. return coords
  43. def apply_boxes(
  44. self, boxes: np.ndarray, original_size: Tuple[int, ...]
  45. ) -> np.ndarray:
  46. """
  47. Expects a numpy array shape Bx4. Requires the original image size
  48. in (H, W) format.
  49. """
  50. boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
  51. return boxes.reshape(-1, 4)
  52. def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
  53. """
  54. Expects batched images with shape BxCxHxW and float format. This
  55. transformation may not exactly match apply_image. apply_image is
  56. the transformation expected by the model.
  57. """
  58. # Expects an image in BCHW format. May not exactly match apply_image.
  59. target_size = self.get_preprocess_shape(
  60. image.shape[0], image.shape[1], self.target_length
  61. )
  62. return F.interpolate(
  63. image, target_size, mode="bilinear", align_corners=False, antialias=True
  64. )
  65. def apply_coords_torch(
  66. self, coords: torch.Tensor, original_size: Tuple[int, ...]
  67. ) -> torch.Tensor:
  68. """
  69. Expects a torch tensor with length 2 in the last dimension. Requires the
  70. original image size in (H, W) format.
  71. """
  72. old_h, old_w = original_size
  73. new_h, new_w = self.get_preprocess_shape(
  74. original_size[0], original_size[1], self.target_length
  75. )
  76. coords = deepcopy(coords).to(torch.float)
  77. coords[..., 0] = coords[..., 0] * (new_w / old_w)
  78. coords[..., 1] = coords[..., 1] * (new_h / old_h)
  79. return coords
  80. def apply_boxes_torch(
  81. self, boxes: torch.Tensor, original_size: Tuple[int, ...]
  82. ) -> torch.Tensor:
  83. """
  84. Expects a torch tensor with shape Bx4. Requires the original image
  85. size in (H, W) format.
  86. """
  87. boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
  88. return boxes.reshape(-1, 4)
  89. @staticmethod
  90. def get_preprocess_shape(
  91. oldh: int, oldw: int, long_side_length: int
  92. ) -> Tuple[int, int]:
  93. """
  94. Compute the output size given input size and target long side length.
  95. """
  96. scale = long_side_length * 1.0 / max(oldh, oldw)
  97. newh, neww = oldh * scale, oldw * scale
  98. neww = int(neww + 0.5)
  99. newh = int(newh + 0.5)
  100. return (newh, neww)