convnext_1d.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. from dataclasses import dataclass
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. from diffusers.configuration_utils import ConfigMixin, register_to_config
  6. from diffusers.models.embeddings import (
  7. GaussianFourierProjection,
  8. TimestepEmbedding,
  9. Timesteps,
  10. )
  11. from diffusers.models.modeling_utils import ModelMixin
  12. from diffusers.utils import BaseOutput
  class ConvNeXtBlock(nn.Module):
      """ConvNeXt block adapted from https://github.com/facebookresearch/ConvNeXt to 1D signals.

      Pipeline: optional additive condition -> optional multiplicative mask -> depthwise dilated
      Conv1d -> LayerNorm over channels -> Linear -> GELU -> Linear -> optional per-channel
      layer scale; the result is added to the (un-masked) block input.

      Args:
          dim (int): Number of input (and output) channels.
          intermediate_dim (int): Hidden width of the pointwise MLP (output size of ``pwconv1``).
          dilation (int, optional): Dilation of the depthwise convolution. Defaults to 1.
          layer_scale_init_value (float, optional): Initial value of the learnable per-channel
              scale ``gamma``. ``None`` or a non-positive value disables layer scaling.
              Defaults to 1e-6.
          kernel_size (int, optional): Kernel size of the depthwise convolution. Must be odd so
              that the symmetric padding preserves the time length. Defaults to 7.

      Raises:
          ValueError: If ``kernel_size`` is even.
      """

      def __init__(
          self,
          dim: int,
          intermediate_dim: int,
          dilation: int = 1,
          layer_scale_init_value: Optional[float] = 1e-6,
          kernel_size: int = 7,
      ):
          super().__init__()
          if kernel_size % 2 != 1:
              raise ValueError(f"kernel_size must be odd, got {kernel_size}.")
          # Depthwise conv (groups=dim). Integer symmetric padding of dilation*(k-1)/2 keeps the
          # output length equal to the input length for odd k.
          self.dwconv = nn.Conv1d(
              dim,
              dim,
              kernel_size=kernel_size,
              groups=dim,
              dilation=dilation,
              padding=dilation * (kernel_size - 1) // 2,
          )
          self.norm = nn.LayerNorm(dim, eps=1e-6)
          # Pointwise (1x1) convs, implemented as Linear layers on the channel-last layout.
          self.pwconv1 = nn.Linear(dim, intermediate_dim)
          self.act = nn.GELU()
          self.pwconv2 = nn.Linear(intermediate_dim, dim)
          self.gamma = (
              nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
              if layer_scale_init_value is not None and layer_scale_init_value > 0
              else None
          )
          # Two 1x1 convs mapping a `dim`-channel condition into the block's channel space.
          self.condition_projection = nn.Sequential(
              nn.Conv1d(dim, dim, 1),
              nn.GELU(),
              nn.Conv1d(dim, dim, 1),
          )

      def forward(
          self,
          x: torch.Tensor,
          condition: Optional[torch.Tensor] = None,
          x_mask: Optional[torch.Tensor] = None,
      ) -> torch.Tensor:
          """Apply the block to ``x``.

          Args:
              x: Input of shape ``(B, C, T)`` with ``C == dim``.
              condition: Optional ``dim``-channel tensor; its projection through
                  ``condition_projection`` is added to ``x`` and must broadcast against it,
                  e.g. ``(B, dim, T)`` or ``(B, dim, 1)``.
              x_mask: Optional multiplicative mask that must broadcast against ``x``
                  (e.g. ``(B, 1, T)``). Positions where it is 0/False are zeroed before the
                  depthwise conv; the residual path is not masked.

          Returns:
              Tensor with the same shape as ``x``.
          """
          residual = x
          if condition is not None:
              x = x + self.condition_projection(condition)
          if x_mask is not None:
              x = x * x_mask
          x = self.dwconv(x)
          x = x.transpose(1, 2)  # (B, C, T) -> (B, T, C) so LayerNorm/Linear act on channels
          x = self.norm(x)
          x = self.pwconv1(x)
          x = self.act(x)
          x = self.pwconv2(x)
          if self.gamma is not None:
              x = self.gamma * x
          x = x.transpose(1, 2)  # (B, T, C) -> (B, C, T)
          return residual + x
  class ConvNext1DModel(ModelMixin, ConfigMixin):
      r"""
      A 1D ConvNeXt model that takes a noisy sample and a timestep and returns a sample-shaped output.

      This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods
      implemented for all models (such as downloading or saving).

      Parameters:
          in_channels (`int`, *optional*, defaults to 128):
              Number of channels entering the input projection. When `condition` is passed to `forward` it is
              concatenated to `sample` along the channel axis first, so this must equal the sample channels plus
              the condition channels. Also used as the embedding size of the `positional` timestep projection.
          out_channels (`int`, *optional*, defaults to 128):
              Number of channels in the output.
          intermediate_dim (`int`, *optional*, defaults to 512):
              Channel width of the ConvNeXt blocks and of the time embedding. Must be even.
          mlp_dim (`int`, *optional*, defaults to 2048):
              Hidden width of the pointwise MLP inside each block.
          num_layers (`int`, *optional*, defaults to 20):
              Number of ConvNeXt blocks.
          dilation_cycle_length (`int`, *optional*, defaults to 4):
              Block `i` uses dilation `2 ** (i % dilation_cycle_length)`.
          time_embedding_type (`str`, *optional*, defaults to `positional`):
              The type of timestep embedding to use. Choose from `positional` or `fourier`.
      """

      _supports_gradient_checkpointing = True

      @register_to_config
      def __init__(
          self,
          in_channels: int = 128,
          out_channels: int = 128,
          intermediate_dim: int = 512,
          mlp_dim: int = 2048,
          num_layers: int = 20,
          dilation_cycle_length: int = 4,
          time_embedding_type: str = "positional",
      ):
          super().__init__()
          if intermediate_dim % 2 != 0:
              raise ValueError("intermediate_dim must be divisible by 2.")

          # Time projection: maps 1-D timesteps to a feature vector consumed by `time_mlp`.
          if time_embedding_type == "fourier":
              self.time_proj = GaussianFourierProjection(
                  intermediate_dim // 2,
                  set_W_to_weight=False,
                  log=False,
                  flip_sin_to_cos=False,
              )
              timestep_input_dim = intermediate_dim
          elif time_embedding_type == "positional":
              self.time_proj = Timesteps(in_channels, False, 0)
              timestep_input_dim = in_channels
          else:
              raise ValueError(
                  f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
              )

          self.time_mlp = TimestepEmbedding(
              timestep_input_dim,
              intermediate_dim,
              act_fn="silu",
              cond_proj_dim=None,  # No conditional projection for now
          )

          # 1x1 projections into and out of the block channel width.
          self.in_proj = nn.Conv1d(in_channels, intermediate_dim, 1)
          self.out_proj = nn.Conv1d(intermediate_dim, out_channels, 1)

          # Blocks; dilation cycles through powers of two.
          self.blocks = nn.ModuleList(
              [
                  ConvNeXtBlock(
                      dim=intermediate_dim,
                      intermediate_dim=mlp_dim,
                      dilation=2 ** (i % dilation_cycle_length),
                  )
                  for i in range(num_layers)
              ]
          )

          self.apply(self._init_weights)
          self.gradient_checkpointing = False

      def _set_gradient_checkpointing(self, module, value: bool = False):
          """Toggle gradient checkpointing; `module` is ignored and the flag is set on this model."""
          self.gradient_checkpointing = value

      def _init_weights(self, m):
          """Truncated-normal (std 0.02) weights and zero biases for Conv1d/Conv2d/Linear layers."""
          if isinstance(m, (nn.Conv2d, nn.Linear, nn.Conv1d)):
              nn.init.trunc_normal_(m.weight, mean=0, std=0.02)
              if m.bias is not None:
                  nn.init.zeros_(m.bias)

      def forward(
          self,
          sample: torch.FloatTensor,
          timestep: Union[torch.Tensor, float, int],
          sample_mask: Optional[torch.Tensor] = None,
          condition: Optional[torch.Tensor] = None,
      ) -> torch.FloatTensor:
          r"""
          The [`ConvNext1DModel`] forward method.

          Args:
              sample (`torch.FloatTensor`):
                  The noisy input tensor of shape `(batch_size, num_channels, sample_size)`.
              timestep (`torch.Tensor` or `float` or `int`):
                  The denoising timestep(s). A Python number, 0-d tensor or length-1 tensor is broadcast to the
                  whole batch; otherwise a 1-D tensor of length `batch_size` is expected.
              sample_mask (`torch.Tensor`, *optional*):
                  Multiplicative mask of shape `(batch_size, sample_size)` or `(batch_size, 1, sample_size)`.
                  Each block multiplies its input by it, so positions where the mask is 0/False are zeroed and
                  positions where it is 1/True are kept.
              condition (`torch.Tensor`, *optional*):
                  Extra features concatenated to `sample` along the channel axis before the input projection.

          Returns:
              `torch.FloatTensor`: Output of shape `(batch_size, out_channels, sample_size)`.
          """
          # 1. time: the embedding layers index `timesteps[:, None]`, so normalize to a 1-D tensor.
          timesteps = timestep
          if not torch.is_tensor(timesteps):
              timesteps = torch.tensor([timesteps], device=sample.device)
          elif timesteps.ndim == 0:
              timesteps = timesteps[None].to(sample.device)
          if timesteps.shape[0] == 1:
              timesteps = timesteps.expand(sample.shape[0])
          t_emb = self.time_proj(timesteps)
          # The projections emit float32; match the sample dtype so half-precision models work.
          t_emb = t_emb.to(dtype=sample.dtype)
          t_emb = self.time_mlp(t_emb)[..., None]  # (B, intermediate_dim, 1), broadcast over time

          # 2. pre-process
          if condition is not None:
              sample = torch.cat([sample, condition], dim=1)
          x = self.in_proj(sample)
          if sample_mask is not None and sample_mask.ndim == 2:
              sample_mask = sample_mask[:, None, :]  # add channel axis to broadcast over channels

          # 3. blocks
          # Hoisted: `is_gradient_checkpointing` scans all submodules and cannot change mid-forward.
          use_checkpointing = self.training and self.is_gradient_checkpointing
          if use_checkpointing:
              # Explicit import: `import torch` alone does not guarantee the submodule is loaded.
              from torch.utils.checkpoint import checkpoint
          for block in self.blocks:
              if use_checkpointing:
                  x = checkpoint(block, x, t_emb, sample_mask, use_reentrant=False)
              else:
                  x = block(x, t_emb, sample_mask)

          # 4. post-process
          return self.out_proj(x)
  202. x = block(x, t_emb, sample_mask)
  203. # 4. post-process
  204. return self.out_proj(x)