# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from functools import partial
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..sam2_utils import MLP, DropPath
from .utils import PatchEmbed, window_partition, window_unpartition


def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    """Apply `pool` (and optionally `norm`) to a channels-last (B, H, W, C) feature map."""
    if pool is None:
        return x
    # (B, H, W, C) -> (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    x = pool(x)
    # (B, C, H', W') -> (B, H', W', C)
    x = x.permute(0, 2, 3, 1)
    if norm:
        x = norm(x)

    return x
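
# Example (illustrative sketch): do_pool keeps the channels-last layout while the
# pooling module changes the spatial size; with an assumed 2x2 max pool the
# resolution is halved. The tensors below are placeholders, not values used
# elsewhere in this module:
#
#   pool = nn.MaxPool2d(kernel_size=2, stride=2)
#   feats = torch.zeros(1, 14, 14, 96)   # (B, H, W, C)
#   do_pool(feats, pool).shape           # -> torch.Size([1, 7, 7, 96])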


class MultiScaleAttention(nn.Module):
    """Multi-head self-attention with optional query pooling, used to downsample at stage transitions."""

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.q_pool = q_pool

        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsampling at stage changes)
        if self.q_pool:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x
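
# Example (illustrative sketch): with a 2x2 query pool, the attention output is
# spatially halved and its channel dim follows dim_out. The sizes below are
# assumed for illustration only:
#
#   attn = MultiScaleAttention(
#       dim=96, dim_out=192, num_heads=2,
#       q_pool=nn.MaxPool2d(kernel_size=2, stride=2),
#   )
#   attn(torch.zeros(1, 8, 8, 96)).shape  # -> torch.Size([1, 4, 4, 192])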


class MultiScaleBlock(nn.Module):
    """Hiera transformer block: (windowed) multi-scale attention followed by an MLP, with optional query pooling."""

    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
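
# Example (illustrative sketch): a stage-transition block with q_stride doubles
# the channels and halves the resolution of both the attention path and the
# pooled skip connection. The sizes below are assumed for illustration:
#
#   blk = MultiScaleBlock(dim=96, dim_out=192, num_heads=2,
#                         q_stride=(2, 2), window_size=8)
#   blk(torch.zeros(1, 16, 16, 96)).shape  # -> torch.Size([1, 8, 8, 192])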


class Hiera(nn.Module):
    """
    Hierarchical vision transformer (Hiera) backbone that returns multi-scale
    feature maps (one per stage when `return_interm_layers` is True).

    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride between stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global attention
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attention in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        weights_path=None,
        return_interm_layers=True,  # return feats from every stage
    ):
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global attention?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # The window size lags by one block: the first block of each new
            # stage still uses the previous stage's window size, because
            # cur_stage is only incremented after window_size is read here.
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

        if weights_path is not None:
            chkpt = torch.load(weights_path, map_location="cpu")
            # Log the missing/unexpected keys reported by load_state_dict.
            logging.info("loading Hiera: %s", self.load_state_dict(chkpt, strict=False))

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed
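
    # Example (illustrative): with the default sizes, self.pos_embed
    # (1, embed_dim, 14, 14) is bicubically interpolated to the patch grid
    # (e.g. 56x56 for an assumed 224x224 input), and self.pos_embed_window
    # (1, embed_dim, 8, 8) is tiled 7x7 times on top of it. This assumes the
    # patch grid is a multiple of window_spec[0]; otherwise the tiled shapes
    # would not match.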

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

    def get_layer_id(self, layer_name):
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        num_layers = self.get_num_layers()

        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("pos_embed") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1
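
    # Example (illustrative): groups parameters by depth, e.g. for layer-wise
    # learning-rate decay. The parameter names below are assumed examples:
    #   get_layer_id("pos_embed")                      -> 0
    #   get_layer_id("blocks.3.attn.qkv.weight")       -> 4
    #   get_layer_id("blocks.23.mlp.layers.1.weight")  -> 24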

    def get_num_layers(self) -> int:
        return len(self.blocks)
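

# Example (illustrative sketch): with the default configuration, the trunk
# returns one channels-first feature map per stage. The 224x224 input below is
# an assumed placeholder:
#
#   trunk = Hiera()
#   feats = trunk(torch.zeros(1, 3, 224, 224))
#   [tuple(f.shape) for f in feats]
#   # -> [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]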