# convnext_1d.py
  1. from dataclasses import dataclass
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. from diffusers.configuration_utils import ConfigMixin, register_to_config
  6. from diffusers.models.embeddings import (
  7. GaussianFourierProjection,
  8. TimestepEmbedding,
  9. Timesteps,
  10. )
  11. from diffusers.models.modeling_utils import ModelMixin
  12. from diffusers.utils import BaseOutput
  13. class ConvNeXtBlock(nn.Module):
  14. """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
  15. Args:
  16. dim (int): Number of input channels.
  17. mlp_dim (int): Dimensionality of the intermediate layer.
  18. layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
  19. Defaults to None.
  20. adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
  21. None means non-conditional LayerNorm. Defaults to None.
  22. """
  23. def __init__(
  24. self,
  25. dim: int,
  26. intermediate_dim: int,
  27. dilation: int = 1,
  28. layer_scale_init_value: Optional[float] = 1e-6,
  29. ):
  30. super().__init__()
  31. self.dwconv = nn.Conv1d(
  32. dim,
  33. dim,
  34. kernel_size=7,
  35. groups=dim,
  36. dilation=dilation,
  37. padding=int(dilation * (7 - 1) / 2),
  38. ) # depthwise conv
  39. self.norm = nn.LayerNorm(dim, eps=1e-6)
  40. self.pwconv1 = nn.Linear(
  41. dim, intermediate_dim
  42. ) # pointwise/1x1 convs, implemented with linear layers
  43. self.act = nn.GELU()
  44. self.pwconv2 = nn.Linear(intermediate_dim, dim)
  45. self.gamma = (
  46. nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
  47. if layer_scale_init_value is not None and layer_scale_init_value > 0
  48. else None
  49. )
  50. self.condition_projection = nn.Sequential(
  51. nn.Conv1d(dim, dim, 1),
  52. nn.GELU(),
  53. nn.Conv1d(dim, dim, 1),
  54. )
  55. def forward(
  56. self,
  57. x: torch.Tensor,
  58. condition: Optional[torch.Tensor] = None,
  59. x_mask: Optional[torch.Tensor] = None,
  60. ) -> torch.Tensor:
  61. residual = x
  62. if condition is not None:
  63. x = x + self.condition_projection(condition)
  64. if x_mask is not None:
  65. x = x * x_mask
  66. x = self.dwconv(x)
  67. x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
  68. x = self.norm(x)
  69. x = self.pwconv1(x)
  70. x = self.act(x)
  71. x = self.pwconv2(x)
  72. if self.gamma is not None:
  73. x = self.gamma * x
  74. x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
  75. x = residual + x
  76. return x
  77. @dataclass
  78. class ConvNext1DOutput(BaseOutput):
  79. """
  80. The output of [`UNet1DModel`].
  81. Args:
  82. sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`):
  83. The hidden states output from the last layer of the model.
  84. """
  85. sample: torch.FloatTensor
  86. class ConvNext1DModel(ModelMixin, ConfigMixin):
  87. r"""
  88. A ConvNext model that takes a noisy sample and a timestep and returns a sample shaped output.
  89. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
  90. for all models (such as downloading or saving).
  91. Parameters:
  92. in_channels (`int`, *optional*, defaults to 128):
  93. Number of channels in the input sample.
  94. out_channels (`int`, *optional*, defaults to 128):
  95. Number of channels in the output.
  96. intermediate_dim (`int`, *optional*, defaults to 512):
  97. Dimensionality of the intermediate blocks.
  98. mlp_dim (`int`, *optional*, defaults to 2048):
  99. Dimensionality of the MLP.
  100. num_layers (`int`, *optional*, defaults to 20):
  101. Number of layers in the model.
  102. dilation_cycle_length (`int`, *optional*, defaults to 4):
  103. Length of the dilation cycle.
  104. time_embedding_type (`str`, *optional*, defaults to `positional`):
  105. The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
  106. time_embedding_dim (`int`, *optional*, defaults to `None`):
  107. An optional override for the dimension of the projected time embedding.
  108. time_embedding_act_fn (`str`, *optional*, defaults to `None`):
  109. Optional activation function to use only once on the time embeddings before they are passed to the rest of
  110. the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
  111. """
  112. _supports_gradient_checkpointing = True
  113. @register_to_config
  114. def __init__(
  115. self,
  116. in_channels: int = 128,
  117. out_channels: int = 128,
  118. intermediate_dim: int = 512,
  119. mlp_dim: int = 2048,
  120. num_layers: int = 20,
  121. dilation_cycle_length: int = 4,
  122. time_embedding_type: str = "positional",
  123. ):
  124. super().__init__()
  125. if intermediate_dim % 2 != 0:
  126. raise ValueError("intermediate_dim must be divisible by 2.")
  127. # time
  128. if time_embedding_type == "fourier":
  129. self.time_proj = GaussianFourierProjection(
  130. intermediate_dim // 2,
  131. set_W_to_weight=False,
  132. log=False,
  133. flip_sin_to_cos=False,
  134. )
  135. timestep_input_dim = intermediate_dim
  136. elif time_embedding_type == "positional":
  137. self.time_proj = Timesteps(in_channels, False, 0)
  138. timestep_input_dim = in_channels
  139. else:
  140. raise ValueError(
  141. f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
  142. )
  143. self.time_mlp = TimestepEmbedding(
  144. timestep_input_dim,
  145. intermediate_dim,
  146. act_fn="silu",
  147. cond_proj_dim=None, # No conditional projection for now
  148. )
  149. # Project to intermediate dim
  150. self.in_proj = nn.Conv1d(in_channels, intermediate_dim, 1)
  151. self.out_proj = nn.Conv1d(intermediate_dim, out_channels, 1)
  152. # Blocks
  153. self.blocks = nn.ModuleList(
  154. [
  155. ConvNeXtBlock(
  156. dim=intermediate_dim,
  157. intermediate_dim=mlp_dim,
  158. dilation=2 ** (i % dilation_cycle_length),
  159. )
  160. for i in range(num_layers)
  161. ]
  162. )
  163. self.gradient_checkpointing = False
  164. def _set_gradient_checkpointing(self, module, value: bool = False):
  165. self.gradient_checkpointing = value
  166. def forward(
  167. self,
  168. sample: torch.FloatTensor,
  169. timestep: Union[torch.Tensor, float, int],
  170. sample_mask: Optional[torch.Tensor] = None,
  171. condition: Optional[torch.Tensor] = None,
  172. ) -> Union[ConvNext1DOutput, Tuple]:
  173. r"""
  174. The [`ConvNext1DModel`] forward method.
  175. Args:
  176. sample (`torch.FloatTensor`):
  177. The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
  178. timestep (`torch.FloatTensor` or `float` or `int`):
  179. The number of timesteps to denoise an input.
  180. sample_mask (`torch.BoolTensor`, *optional*):
  181. A mask of the same shape as `sample` that indicates which elements are invalid.
  182. True means the element is invalid and should be masked out.
  183. return_dict (`bool`, *optional*, defaults to `True`):
  184. Whether or not to return a [`~models.unet_1d.ConvNext1DOutput`] instead of a plain tuple.
  185. Returns:
  186. [`~models.unet_1d.ConvNext1DOutput`] or `tuple`:
  187. If `return_dict` is True, an [`~models.unet_1d.ConvNext1DOutput`] is returned, otherwise a `tuple` is
  188. returned where the first element is the sample tensor.
  189. """
  190. # 1. time
  191. t_emb = self.time_proj(timestep)
  192. t_emb = self.time_mlp(t_emb)[..., None]
  193. # 2. pre-process
  194. if condition is not None:
  195. sample = torch.cat([sample, condition], dim=1)
  196. x = self.in_proj(sample)
  197. if sample_mask.ndim == 2:
  198. sample_mask = sample_mask[:, None, :]
  199. # 3. blocks
  200. for block in self.blocks:
  201. if self.training and self.is_gradient_checkpointing:
  202. x = torch.utils.checkpoint.checkpoint(block, x, t_emb, sample_mask)
  203. else:
  204. x = block(x, t_emb, sample_mask)
  205. # 4. post-process
  206. return self.out_proj(x)