image_encoder.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. from typing import Optional, Tuple, Type
  6. import torch
  7. import torch.nn as nn
  8. import torch.nn.functional as F
  9. from .common import LayerNorm2d, MLPBlock
  10. # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
  11. class ImageEncoderViT(nn.Module):
  12. def __init__(
  13. self,
  14. img_size: int = 1024,
  15. patch_size: int = 16,
  16. in_chans: int = 3,
  17. embed_dim: int = 768,
  18. depth: int = 12,
  19. num_heads: int = 12,
  20. mlp_ratio: float = 4.0,
  21. out_chans: int = 256,
  22. qkv_bias: bool = True,
  23. norm_layer: Type[nn.Module] = nn.LayerNorm,
  24. act_layer: Type[nn.Module] = nn.GELU,
  25. use_abs_pos: bool = True,
  26. use_rel_pos: bool = False,
  27. rel_pos_zero_init: bool = True,
  28. window_size: int = 0,
  29. global_attn_indexes: Tuple[int, ...] = (),
  30. ) -> None:
  31. """
  32. Args:
  33. img_size (int): Input image size.
  34. patch_size (int): Patch size.
  35. in_chans (int): Number of input image channels.
  36. embed_dim (int): Patch embedding dimension.
  37. depth (int): Depth of ViT.
  38. num_heads (int): Number of attention heads in each ViT block.
  39. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  40. qkv_bias (bool): If True, add a learnable bias to query, key, value.
  41. norm_layer (nn.Module): Normalization layer.
  42. act_layer (nn.Module): Activation layer.
  43. use_abs_pos (bool): If True, use absolute positional embeddings.
  44. use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
  45. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  46. window_size (int): Window size for window attention blocks.
  47. global_attn_indexes (list): Indexes for blocks using global attention.
  48. """
  49. super().__init__()
  50. self.img_size = img_size
  51. self.patch_embed = PatchEmbed(
  52. kernel_size=(patch_size, patch_size),
  53. stride=(patch_size, patch_size),
  54. in_chans=in_chans,
  55. embed_dim=embed_dim,
  56. )
  57. self.pos_embed: Optional[nn.Parameter] = None
  58. if use_abs_pos:
  59. # Initialize absolute positional embedding with pretrain image size.
  60. self.pos_embed = nn.Parameter(
  61. torch.zeros(
  62. 1, img_size // patch_size, img_size // patch_size, embed_dim
  63. )
  64. )
  65. self.blocks = nn.ModuleList()
  66. for i in range(depth):
  67. block = Block(
  68. dim=embed_dim,
  69. num_heads=num_heads,
  70. mlp_ratio=mlp_ratio,
  71. qkv_bias=qkv_bias,
  72. norm_layer=norm_layer,
  73. act_layer=act_layer,
  74. use_rel_pos=use_rel_pos,
  75. rel_pos_zero_init=rel_pos_zero_init,
  76. window_size=window_size if i not in global_attn_indexes else 0,
  77. input_size=(img_size // patch_size, img_size // patch_size),
  78. )
  79. self.blocks.append(block)
  80. self.neck = nn.Sequential(
  81. nn.Conv2d(
  82. embed_dim,
  83. out_chans,
  84. kernel_size=1,
  85. bias=False,
  86. ),
  87. LayerNorm2d(out_chans),
  88. nn.Conv2d(
  89. out_chans,
  90. out_chans,
  91. kernel_size=3,
  92. padding=1,
  93. bias=False,
  94. ),
  95. LayerNorm2d(out_chans),
  96. )
  97. def forward(self, x: torch.Tensor) -> torch.Tensor:
  98. x = self.patch_embed(x)
  99. if self.pos_embed is not None:
  100. x = x + self.pos_embed
  101. for blk in self.blocks:
  102. x = blk(x)
  103. x = self.neck(x.permute(0, 3, 1, 2))
  104. return x
  105. class Block(nn.Module):
  106. """Transformer blocks with support of window attention and residual propagation blocks"""
  107. def __init__(
  108. self,
  109. dim: int,
  110. num_heads: int,
  111. mlp_ratio: float = 4.0,
  112. qkv_bias: bool = True,
  113. norm_layer: Type[nn.Module] = nn.LayerNorm,
  114. act_layer: Type[nn.Module] = nn.GELU,
  115. use_rel_pos: bool = False,
  116. rel_pos_zero_init: bool = True,
  117. window_size: int = 0,
  118. input_size: Optional[Tuple[int, int]] = None,
  119. ) -> None:
  120. """
  121. Args:
  122. dim (int): Number of input channels.
  123. num_heads (int): Number of attention heads in each ViT block.
  124. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  125. qkv_bias (bool): If True, add a learnable bias to query, key, value.
  126. norm_layer (nn.Module): Normalization layer.
  127. act_layer (nn.Module): Activation layer.
  128. use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
  129. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  130. window_size (int): Window size for window attention blocks. If it equals 0, then
  131. use global attention.
  132. input_size (int or None): Input resolution for calculating the relative positional
  133. parameter size.
  134. """
  135. super().__init__()
  136. self.norm1 = norm_layer(dim)
  137. self.attn = Attention(
  138. dim,
  139. num_heads=num_heads,
  140. qkv_bias=qkv_bias,
  141. use_rel_pos=use_rel_pos,
  142. rel_pos_zero_init=rel_pos_zero_init,
  143. input_size=input_size if window_size == 0 else (window_size, window_size),
  144. )
  145. self.norm2 = norm_layer(dim)
  146. self.mlp = MLPBlock(
  147. embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
  148. )
  149. self.window_size = window_size
  150. def forward(self, x: torch.Tensor) -> torch.Tensor:
  151. shortcut = x
  152. x = self.norm1(x)
  153. # Window partition
  154. if self.window_size > 0:
  155. H, W = x.shape[1], x.shape[2]
  156. x, pad_hw = window_partition(x, self.window_size)
  157. x = self.attn(x)
  158. # Reverse window partition
  159. if self.window_size > 0:
  160. x = window_unpartition(x, self.window_size, pad_hw, (H, W))
  161. x = shortcut + x
  162. x = x + self.mlp(self.norm2(x))
  163. return x
  164. class Attention(nn.Module):
  165. """Multi-head Attention block with relative position embeddings."""
  166. def __init__(
  167. self,
  168. dim: int,
  169. num_heads: int = 8,
  170. qkv_bias: bool = True,
  171. use_rel_pos: bool = False,
  172. rel_pos_zero_init: bool = True,
  173. input_size: Optional[Tuple[int, int]] = None,
  174. ) -> None:
  175. """
  176. Args:
  177. dim (int): Number of input channels.
  178. num_heads (int): Number of attention heads.
  179. qkv_bias (bool: If True, add a learnable bias to query, key, value.
  180. rel_pos (bool): If True, add relative positional embeddings to the attention map.
  181. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  182. input_size (int or None): Input resolution for calculating the relative positional
  183. parameter size.
  184. """
  185. super().__init__()
  186. self.num_heads = num_heads
  187. head_dim = dim // num_heads
  188. self.scale = head_dim**-0.5
  189. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  190. self.proj = nn.Linear(dim, dim)
  191. self.use_rel_pos = use_rel_pos
  192. if self.use_rel_pos:
  193. assert (
  194. input_size is not None
  195. ), "Input size must be provided if using relative positional encoding."
  196. # initialize relative positional embeddings
  197. self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
  198. self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
  199. def forward(self, x: torch.Tensor) -> torch.Tensor:
  200. B, H, W, _ = x.shape
  201. # qkv with shape (3, B, nHead, H * W, C)
  202. qkv = (
  203. self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  204. )
  205. # q, k, v with shape (B * nHead, H * W, C)
  206. q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
  207. attn = (q * self.scale) @ k.transpose(-2, -1)
  208. if self.use_rel_pos:
  209. attn = add_decomposed_rel_pos(
  210. attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
  211. )
  212. attn = attn.softmax(dim=-1)
  213. x = (
  214. (attn @ v)
  215. .view(B, self.num_heads, H, W, -1)
  216. .permute(0, 2, 3, 1, 4)
  217. .reshape(B, H, W, -1)
  218. )
  219. x = self.proj(x)
  220. return x
  221. def window_partition(
  222. x: torch.Tensor, window_size: int
  223. ) -> Tuple[torch.Tensor, Tuple[int, int]]:
  224. """
  225. Partition into non-overlapping windows with padding if needed.
  226. Args:
  227. x (tensor): input tokens with [B, H, W, C].
  228. window_size (int): window size.
  229. Returns:
  230. windows: windows after partition with [B * num_windows, window_size, window_size, C].
  231. (Hp, Wp): padded height and width before partition
  232. """
  233. B, H, W, C = x.shape
  234. pad_h = (window_size - H % window_size) % window_size
  235. pad_w = (window_size - W % window_size) % window_size
  236. if pad_h > 0 or pad_w > 0:
  237. x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
  238. Hp, Wp = H + pad_h, W + pad_w
  239. x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
  240. windows = (
  241. x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  242. )
  243. return windows, (Hp, Wp)
  244. def window_unpartition(
  245. windows: torch.Tensor,
  246. window_size: int,
  247. pad_hw: Tuple[int, int],
  248. hw: Tuple[int, int],
  249. ) -> torch.Tensor:
  250. """
  251. Window unpartition into original sequences and removing padding.
  252. Args:
  253. x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
  254. window_size (int): window size.
  255. pad_hw (Tuple): padded height and width (Hp, Wp).
  256. hw (Tuple): original height and width (H, W) before padding.
  257. Returns:
  258. x: unpartitioned sequences with [B, H, W, C].
  259. """
  260. Hp, Wp = pad_hw
  261. H, W = hw
  262. B = windows.shape[0] // (Hp * Wp // window_size // window_size)
  263. x = windows.view(
  264. B, Hp // window_size, Wp // window_size, window_size, window_size, -1
  265. )
  266. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
  267. if Hp > H or Wp > W:
  268. x = x[:, :H, :W, :].contiguous()
  269. return x
  270. def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
  271. """
  272. Get relative positional embeddings according to the relative positions of
  273. query and key sizes.
  274. Args:
  275. q_size (int): size of query q.
  276. k_size (int): size of key k.
  277. rel_pos (Tensor): relative position embeddings (L, C).
  278. Returns:
  279. Extracted positional embeddings according to relative positions.
  280. """
  281. max_rel_dist = int(2 * max(q_size, k_size) - 1)
  282. # Interpolate rel pos if needed.
  283. if rel_pos.shape[0] != max_rel_dist:
  284. # Interpolate rel pos.
  285. rel_pos_resized = F.interpolate(
  286. rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
  287. size=max_rel_dist,
  288. mode="linear",
  289. )
  290. rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
  291. else:
  292. rel_pos_resized = rel_pos
  293. # Scale the coords with short length if shapes for q and k are different.
  294. q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
  295. k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
  296. relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  297. return rel_pos_resized[relative_coords.long()]
  298. def add_decomposed_rel_pos(
  299. attn: torch.Tensor,
  300. q: torch.Tensor,
  301. rel_pos_h: torch.Tensor,
  302. rel_pos_w: torch.Tensor,
  303. q_size: Tuple[int, int],
  304. k_size: Tuple[int, int],
  305. ) -> torch.Tensor:
  306. """
  307. Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
  308. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
  309. Args:
  310. attn (Tensor): attention map.
  311. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
  312. rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
  313. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
  314. q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
  315. k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
  316. Returns:
  317. attn (Tensor): attention map with added relative positional embeddings.
  318. """
  319. q_h, q_w = q_size
  320. k_h, k_w = k_size
  321. Rh = get_rel_pos(q_h, k_h, rel_pos_h)
  322. Rw = get_rel_pos(q_w, k_w, rel_pos_w)
  323. B, _, dim = q.shape
  324. r_q = q.reshape(B, q_h, q_w, dim)
  325. rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
  326. rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
  327. attn = (
  328. attn.view(B, q_h, q_w, k_h, k_w)
  329. + rel_h[:, :, :, :, None]
  330. + rel_w[:, :, :, None, :]
  331. ).view(B, q_h * q_w, k_h * k_w)
  332. return attn
  333. class PatchEmbed(nn.Module):
  334. """
  335. Image to Patch Embedding.
  336. """
  337. def __init__(
  338. self,
  339. kernel_size: Tuple[int, int] = (16, 16),
  340. stride: Tuple[int, int] = (16, 16),
  341. padding: Tuple[int, int] = (0, 0),
  342. in_chans: int = 3,
  343. embed_dim: int = 768,
  344. ) -> None:
  345. """
  346. Args:
  347. kernel_size (Tuple): kernel size of the projection layer.
  348. stride (Tuple): stride of the projection layer.
  349. padding (Tuple): padding size of the projection layer.
  350. in_chans (int): Number of input image channels.
  351. embed_dim (int): embed_dim (int): Patch embedding dimension.
  352. """
  353. super().__init__()
  354. self.proj = nn.Conv2d(
  355. in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
  356. )
  357. def forward(self, x: torch.Tensor) -> torch.Tensor:
  358. x = self.proj(x)
  359. # B C H W -> B H W C
  360. x = x.permute(0, 2, 3, 1)
  361. return x