hieradet.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from functools import partial
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..sam2_utils import MLP, DropPath
from .utils import PatchEmbed, window_partition, window_unpartition


def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    if pool is None:
        return x
    # (B, H, W, C) -> (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    x = pool(x)
    # (B, C, H', W') -> (B, H', W', C)
    x = x.permute(0, 2, 3, 1)
    if norm:
        x = norm(x)

    return x
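
# Example: with pool = nn.MaxPool2d(kernel_size=2, stride=2), a channels-last
# input of shape (B, 14, 14, C) comes back as (B, 7, 7, C); the permutes are
# only there because nn.MaxPool2d expects channels-first (B, C, H, W) input.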


class MultiScaleAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads

        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes)
        if self.q_pool:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x
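
# Example: MultiScaleAttention(dim=96, dim_out=96, num_heads=1) maps
# (B, H, W, 96) -> (B, H, W, 96). With dim_out=192, num_heads=2 and
# q_pool=nn.MaxPool2d(2, 2), only the queries are pooled, so the spatial size
# is halved as well: (B, 56, 56, 96) -> (B, 28, 28, 192).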


class MultiScaleBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
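
# Example: a stage-transition block such as MultiScaleBlock(dim=96, dim_out=192,
# num_heads=2, q_stride=(2, 2), window_size=8) doubles the channel dim and
# halves the spatial dims, i.e. (B, 56, 56, 96) -> (B, 28, 28, 192); the
# shortcut is projected and pooled so the residual shapes still match.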


class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        weights_path=None,
        return_interm_layers=True,  # return feats from every stage
    ):
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # The window size lags by a block, so the first block of the next
            # stage uses the previous stage's initial window size and the
            # current stage's final window size.
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

        if weights_path is not None:
            chkpt = torch.load(weights_path, map_location="cpu")
            logging.info("loading Hiera: %s", self.load_state_dict(chkpt, strict=False))

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

    def get_layer_id(self, layer_name):
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        num_layers = self.get_num_layers()

        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("pos_embed") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1

    def get_num_layers(self) -> int:
        return len(self.blocks)
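

if __name__ == "__main__":
    # Illustrative smoke test (a sketch, not prescribed by the library): run as
    # a module, e.g. `python -m <package>.hieradet`, so the relative imports
    # above resolve; the package path is a placeholder.
    backbone = Hiera()  # default configuration, randomly initialized
    dummy = torch.zeros(1, 3, 256, 256)
    with torch.no_grad():
        feats = backbone(dummy)
    # One (B, C, H, W) feature map per stage; channel counts match
    # backbone.channel_list in reverse order.
    for f in feats:
        print(tuple(f.shape))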