# encoder.py
from functools import partial

import torch
import torch.nn.functional as F
from torch import nn
  5. def drop_path(
  6. x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
  7. ):
  8. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  9. This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
  10. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  11. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
  12. changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
  13. 'survival rate' as the argument.
  14. """ # noqa: E501
  15. if drop_prob == 0.0 or not training:
  16. return x
  17. keep_prob = 1 - drop_prob
  18. shape = (x.shape[0],) + (1,) * (
  19. x.ndim - 1
  20. ) # work with diff dim tensors, not just 2D ConvNets
  21. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  22. if keep_prob > 0.0 and scale_by_keep:
  23. random_tensor.div_(keep_prob)
  24. return x * random_tensor
  25. class DropPath(nn.Module):
  26. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
  27. def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
  28. super(DropPath, self).__init__()
  29. self.drop_prob = drop_prob
  30. self.scale_by_keep = scale_by_keep
  31. def forward(self, x):
  32. return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
  33. def extra_repr(self):
  34. return f"drop_prob={round(self.drop_prob,3):0.3f}"
  35. class LayerNorm(nn.Module):
  36. r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
  37. The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
  38. shape (batch_size, height, width, channels) while channels_first corresponds to inputs
  39. with shape (batch_size, channels, height, width).
  40. """ # noqa: E501
  41. def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
  42. super().__init__()
  43. self.weight = nn.Parameter(torch.ones(normalized_shape))
  44. self.bias = nn.Parameter(torch.zeros(normalized_shape))
  45. self.eps = eps
  46. self.data_format = data_format
  47. if self.data_format not in ["channels_last", "channels_first"]:
  48. raise NotImplementedError
  49. self.normalized_shape = (normalized_shape,)
  50. def forward(self, x):
  51. if self.data_format == "channels_last":
  52. return F.layer_norm(
  53. x, self.normalized_shape, self.weight, self.bias, self.eps
  54. )
  55. elif self.data_format == "channels_first":
  56. u = x.mean(1, keepdim=True)
  57. s = (x - u).pow(2).mean(1, keepdim=True)
  58. x = (x - u) / torch.sqrt(s + self.eps)
  59. x = self.weight[:, None] * x + self.bias[:, None]
  60. return x
  61. class ConvNeXtBlock(nn.Module):
  62. r"""ConvNeXt Block. There are two equivalent implementations:
  63. (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
  64. (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
  65. We use (2) as we find it slightly faster in PyTorch
  66. Args:
  67. dim (int): Number of input channels.
  68. drop_path (float): Stochastic depth rate. Default: 0.0
  69. layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
  70. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
  71. kernel_size (int): Kernel size for depthwise conv. Default: 7.
  72. dilation (int): Dilation for depthwise conv. Default: 1.
  73. """ # noqa: E501
  74. def __init__(
  75. self,
  76. dim: int,
  77. drop_path: float = 0.0,
  78. layer_scale_init_value: float = 1e-6,
  79. mlp_ratio: float = 4.0,
  80. kernel_size: int = 7,
  81. dilation: int = 1,
  82. ):
  83. super().__init__()
  84. self.dwconv = nn.Conv1d(
  85. dim,
  86. dim,
  87. kernel_size=kernel_size,
  88. padding=int(dilation * (kernel_size - 1) / 2),
  89. groups=dim,
  90. ) # depthwise conv
  91. self.norm = LayerNorm(dim, eps=1e-6)
  92. self.pwconv1 = nn.Linear(
  93. dim, int(mlp_ratio * dim)
  94. ) # pointwise/1x1 convs, implemented with linear layers
  95. self.act = nn.GELU()
  96. self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
  97. self.gamma = (
  98. nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
  99. if layer_scale_init_value > 0
  100. else None
  101. )
  102. self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  103. def forward(self, x, apply_residual: bool = True):
  104. input = x
  105. x = self.dwconv(x)
  106. x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
  107. x = self.norm(x)
  108. x = self.pwconv1(x)
  109. x = self.act(x)
  110. x = self.pwconv2(x)
  111. if self.gamma is not None:
  112. x = self.gamma * x
  113. x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
  114. x = self.drop_path(x)
  115. if apply_residual:
  116. x = input + x
  117. return x
  118. class ParallelConvNeXtBlock(nn.Module):
  119. def __init__(self, kernel_sizes: list[int], *args, **kwargs):
  120. super().__init__()
  121. self.blocks = nn.ModuleList(
  122. [
  123. ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
  124. for kernel_size in kernel_sizes
  125. ]
  126. )
  127. def forward(self, x: torch.Tensor) -> torch.Tensor:
  128. return torch.stack(
  129. [block(x, apply_residual=False) for block in self.blocks] + [x],
  130. dim=1,
  131. ).sum(dim=1)
  132. class ConvNeXtEncoder(nn.Module):
  133. def __init__(
  134. self,
  135. input_channels=3,
  136. depths=[3, 3, 9, 3],
  137. dims=[96, 192, 384, 768],
  138. drop_path_rate=0.0,
  139. layer_scale_init_value=1e-6,
  140. kernel_sizes: tuple[int] = (7,),
  141. ):
  142. super().__init__()
  143. assert len(depths) == len(dims)
  144. self.channel_layers = nn.ModuleList()
  145. stem = nn.Sequential(
  146. nn.Conv1d(
  147. input_channels,
  148. dims[0],
  149. kernel_size=7,
  150. padding=3,
  151. padding_mode="replicate",
  152. ),
  153. LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
  154. )
  155. self.channel_layers.append(stem)
  156. for i in range(len(depths) - 1):
  157. mid_layer = nn.Sequential(
  158. LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
  159. nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
  160. )
  161. self.channel_layers.append(mid_layer)
  162. block_fn = (
  163. partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
  164. if len(kernel_sizes) == 1
  165. else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
  166. )
  167. self.stages = nn.ModuleList()
  168. drop_path_rates = [
  169. x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
  170. ]
  171. cur = 0
  172. for i in range(len(depths)):
  173. stage = nn.Sequential(
  174. *[
  175. block_fn(
  176. dim=dims[i],
  177. drop_path=drop_path_rates[cur + j],
  178. layer_scale_init_value=layer_scale_init_value,
  179. )
  180. for j in range(depths[i])
  181. ]
  182. )
  183. self.stages.append(stage)
  184. cur += depths[i]
  185. self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
  186. self.apply(self._init_weights)
  187. def _init_weights(self, m):
  188. if isinstance(m, (nn.Conv1d, nn.Linear)):
  189. nn.init.trunc_normal_(m.weight, std=0.02)
  190. nn.init.constant_(m.bias, 0)
  191. def forward(
  192. self,
  193. x: torch.Tensor,
  194. ) -> torch.Tensor:
  195. for channel_layer, stage in zip(self.channel_layers, self.stages):
  196. x = channel_layer(x)
  197. x = stage(x)
  198. return self.norm(x)