  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. """Some utilities for backbones, in particular for windowing"""
  6. from typing import Tuple
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. def window_partition(x, window_size):
  11. """
  12. Partition into non-overlapping windows with padding if needed.
  13. Args:
  14. x (tensor): input tokens with [B, H, W, C].
  15. window_size (int): window size.
  16. Returns:
  17. windows: windows after partition with [B * num_windows, window_size, window_size, C].
  18. (Hp, Wp): padded height and width before partition
  19. """
  20. B, H, W, C = x.shape
  21. pad_h = (window_size - H % window_size) % window_size
  22. pad_w = (window_size - W % window_size) % window_size
  23. if pad_h > 0 or pad_w > 0:
  24. x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
  25. Hp, Wp = H + pad_h, W + pad_w
  26. x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
  27. windows = (
  28. x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  29. )
  30. return windows, (Hp, Wp)
  31. def window_unpartition(windows, window_size, pad_hw, hw):
  32. """
  33. Window unpartition into original sequences and removing padding.
  34. Args:
  35. x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
  36. window_size (int): window size.
  37. pad_hw (Tuple): padded height and width (Hp, Wp).
  38. hw (Tuple): original height and width (H, W) before padding.
  39. Returns:
  40. x: unpartitioned sequences with [B, H, W, C].
  41. """
  42. Hp, Wp = pad_hw
  43. H, W = hw
  44. B = windows.shape[0] // (Hp * Wp // window_size // window_size)
  45. x = windows.view(
  46. B, Hp // window_size, Wp // window_size, window_size, window_size, -1
  47. )
  48. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
  49. if Hp > H or Wp > W:
  50. x = x[:, :H, :W, :].contiguous()
  51. return x
  52. class PatchEmbed(nn.Module):
  53. """
  54. Image to Patch Embedding.
  55. """
  56. def __init__(
  57. self,
  58. kernel_size: Tuple[int, ...] = (7, 7),
  59. stride: Tuple[int, ...] = (4, 4),
  60. padding: Tuple[int, ...] = (3, 3),
  61. in_chans: int = 3,
  62. embed_dim: int = 768,
  63. ):
  64. """
  65. Args:
  66. kernel_size (Tuple): kernel size of the projection layer.
  67. stride (Tuple): stride of the projection layer.
  68. padding (Tuple): padding size of the projection layer.
  69. in_chans (int): Number of input image channels.
  70. embed_dim (int): embed_dim (int): Patch embedding dimension.
  71. """
  72. super().__init__()
  73. self.proj = nn.Conv2d(
  74. in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
  75. )
  76. def forward(self, x: torch.Tensor) -> torch.Tensor:
  77. x = self.proj(x)
  78. # B C H W -> B H W C
  79. x = x.permute(0, 2, 3, 1)
  80. return x