# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Some utilities for backbones, in particular for windowing"""

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.
    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # F.pad takes pads for the last dims first: (C, C, W_left, W_right, H_left, H_right)
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )
    return windows, (Hp, Wp)
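

# Illustrative sketch (not part of the original module): with window_size=7, a
# 10x10 feature map is zero-padded to 14x14 and split into 2x2 = 4 windows per
# image, so the output batch dimension is B * num_windows.
def _example_window_partition():
    x = torch.randn(2, 10, 10, 32)  # [B, H, W, C]
    windows, (Hp, Wp) = window_partition(x, window_size=7)
    assert windows.shape == (2 * 4, 7, 7, 32)  # B * num_windows window tiles
    assert (Hp, Wp) == (14, 14)  # padded spatial size before partitioning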


def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and removing padding.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.
    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    # Recover the batch size from the total number of windows.
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(
        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        # Crop away the padding added in window_partition.
        x = x[:, :H, :W, :].contiguous()
    return x
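

# Illustrative sketch (not part of the original module): window_unpartition is
# the inverse of window_partition, so a partition/unpartition round trip (using
# the padded size returned by window_partition) recovers the input exactly.
def _example_window_roundtrip():
    x = torch.randn(1, 9, 13, 16)  # [B, H, W, C], not divisible by window_size
    windows, pad_hw = window_partition(x, window_size=4)
    y = window_unpartition(windows, window_size=4, pad_hw=pad_hw, hw=(9, 13))
    assert torch.equal(x, y)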


class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, ...] = (7, 7),
        stride: Tuple[int, ...] = (4, 4),
        padding: Tuple[int, ...] = (3, 3),
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x
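

# Illustrative sketch (not part of the original module): with the default
# kernel_size=7, stride=4, padding=3, a 224x224 image maps to a 56x56 grid of
# 768-dim patch embeddings, returned channels-last as [B, H, W, C].
def _example_patch_embed():
    patch_embed = PatchEmbed()
    imgs = torch.randn(2, 3, 224, 224)  # [B, C, H, W]
    feats = patch_embed(imgs)
    assert feats.shape == (2, 56, 56, 768)  # (224 + 2*3 - 7) // 4 + 1 == 56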