transformer.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Tuple, Type

import torch
from torch import Tensor, nn

from .common import MLPBlock


class TwoWayTransformer(nn.Module):
    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
    ) -> None:
        """
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.

        Args:
          depth (int): number of layers in the transformer
          embedding_dim (int): the channel dimension for the input embeddings
          num_heads (int): the number of heads for multihead attention. Must
            divide embedding_dim
          mlp_dim (int): the channel dimension internal to the MLP block
          activation (nn.Module): the activation to use in the MLP block
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),
                )
            )

        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          image_embedding (torch.Tensor): image to attend to. Should be shape
            B x embedding_dim x h x w for any h and w.
          image_pe (torch.Tensor): the positional encoding to add to the image. Must
            have the same shape as image_embedding.
          point_embedding (torch.Tensor): the embedding to add to the query points.
            Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
          torch.Tensor: the processed point_embedding
          torch.Tensor: the processed image_embedding
        """
        # BxCxHxW -> BxHWxC == B x N_image_tokens x C
        bs, c, h, w = image_embedding.shape
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply transformer blocks and final layernorm
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )

        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

        return queries, keys
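

# Usage sketch (not part of the original file): how the decoder is typically
# driven. The 256-channel, 64x64 image embedding and the handful of point
# tokens below are illustrative assumptions; any h, w, and N_points work as
# long as the channel dimension matches embedding_dim.
def _example_two_way_transformer() -> None:
    transformer = TwoWayTransformer(
        depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048
    )
    image_embedding = torch.randn(1, 256, 64, 64)  # B x C x H x W
    image_pe = torch.randn(1, 256, 64, 64)         # same shape as image_embedding
    point_embedding = torch.randn(1, 5, 256)       # B x N_points x C
    queries, keys = transformer(image_embedding, image_pe, point_embedding)
    assert queries.shape == (1, 5, 256)       # processed point tokens
    assert keys.shape == (1, 64 * 64, 256)    # processed (flattened) image tokens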


class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.

        Arguments:
          embedding_dim (int): the channel dimension of the embeddings
          num_heads (int): the number of heads in the attention layers
          mlp_dim (int): the hidden dimension of the mlp block
          activation (nn.Module): the activation of the mlp block
          skip_first_layer_pe (bool): skip the PE on the first layer
        """
        super().__init__()
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
        self.norm3 = nn.LayerNorm(embedding_dim)

        self.norm4 = nn.LayerNorm(embedding_dim)
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # Self attention block
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out
        queries = self.norm1(queries)

        # Cross attention block, tokens attending to image embedding
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm2(queries)

        # MLP block
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out
        queries = self.norm3(queries)

        # Cross attention block, image embedding attending to tokens
        q = queries + query_pe
        k = keys + key_pe
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = keys + attn_out
        keys = self.norm4(keys)

        return queries, keys
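

# Usage sketch (not part of the original file): a single block mixes a small
# set of sparse query tokens with a larger set of dense image tokens and
# returns both updated; the shapes below are illustrative assumptions.
def _example_two_way_attention_block() -> None:
    block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
    queries = torch.randn(1, 5, 256)       # sparse tokens (e.g. prompt tokens)
    keys = torch.randn(1, 64 * 64, 256)    # dense tokens (flattened image embedding)
    query_pe = torch.randn(1, 5, 256)      # positional encoding for the queries
    key_pe = torch.randn(1, 64 * 64, 256)  # positional encoding for the keys
    queries, keys = block(queries=queries, keys=keys, query_pe=query_pe, key_pe=key_pe)
    assert queries.shape == (1, 5, 256) and keys.shape == (1, 64 * 64, 256)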


class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.num_heads = num_heads
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim // downsample_rate."

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape
        x = x.transpose(1, 2)
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        # Attention
        _, _, _, c_per_head = q.shape
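        # Scaled dot-product attention: weights are softmax(q @ k^T / sqrt(c_per_head)),
        # computed independently per head over the token dimension.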
        attn = q @ k.permute(0, 1, 3, 2)  # B x N_heads x N_tokens x N_tokens
        attn = attn / math.sqrt(c_per_head)
        attn = torch.softmax(attn, dim=-1)

        # Get output
        out = attn @ v
        out = self._recombine_heads(out)
        out = self.out_proj(out)

        return out
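

# Usage sketch (not part of the original file): with downsample_rate=2 the
# queries, keys, and values are projected to half the channel width before
# attention and projected back afterwards; the sizes here are assumptions.
def _example_attention() -> None:
    attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
    q = torch.randn(1, 5, 256)        # B x N_q_tokens x C
    k = torch.randn(1, 64 * 64, 256)  # B x N_k_tokens x C
    v = torch.randn(1, 64 * 64, 256)  # B x N_k_tokens x C
    out = attn(q=q, k=k, v=v)
    assert out.shape == (1, 5, 256)   # output keeps the query token layout


if __name__ == "__main__":
    # Quick shape check for the sketch above.
    _example_attention()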