convnext_1d.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. from dataclasses import dataclass
  2. from typing import Optional, Tuple, Union
  3. import torch
  4. import torch.nn as nn
  5. from diffusers.configuration_utils import ConfigMixin, register_to_config
  6. from diffusers.models.embeddings import (
  7. GaussianFourierProjection,
  8. TimestepEmbedding,
  9. Timesteps,
  10. )
  11. from diffusers.models.modeling_utils import ModelMixin
  12. from diffusers.utils import BaseOutput
  13. class ConvNeXtBlock(nn.Module):
  14. """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
  15. Args:
  16. dim (int): Number of input channels.
  17. mlp_dim (int): Dimensionality of the intermediate layer.
  18. layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
  19. Defaults to None.
  20. adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
  21. None means non-conditional LayerNorm. Defaults to None.
  22. """
  23. def __init__(
  24. self,
  25. dim: int,
  26. intermediate_dim: int,
  27. dilation: int = 1,
  28. layer_scale_init_value: Optional[float] = 1e-6,
  29. ):
  30. super().__init__()
  31. self.dwconv = nn.Conv1d(
  32. dim,
  33. dim,
  34. kernel_size=7,
  35. groups=dim,
  36. dilation=dilation,
  37. padding=int(dilation * (7 - 1) / 2),
  38. ) # depthwise conv
  39. self.norm = nn.LayerNorm(dim, eps=1e-6)
  40. self.pwconv1 = nn.Linear(
  41. dim, intermediate_dim
  42. ) # pointwise/1x1 convs, implemented with linear layers
  43. self.act = nn.GELU()
  44. self.pwconv2 = nn.Linear(intermediate_dim, dim)
  45. self.gamma = (
  46. nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
  47. if layer_scale_init_value is not None and layer_scale_init_value > 0
  48. else None
  49. )
  50. self.condition_projection = nn.Sequential(
  51. nn.Conv1d(dim, dim, 1),
  52. nn.GELU(),
  53. nn.Conv1d(dim, dim, 1),
  54. )
  55. def forward(
  56. self,
  57. x: torch.Tensor,
  58. condition: Optional[torch.Tensor] = None,
  59. x_mask: Optional[torch.Tensor] = None,
  60. ) -> torch.Tensor:
  61. residual = x
  62. if condition is not None:
  63. x = x + self.condition_projection(condition)
  64. if x_mask is not None:
  65. x = x * x_mask
  66. x = self.dwconv(x)
  67. x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
  68. x = self.norm(x)
  69. x = self.pwconv1(x)
  70. x = self.act(x)
  71. x = self.pwconv2(x)
  72. if self.gamma is not None:
  73. x = self.gamma * x
  74. x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
  75. x = residual + x
  76. return x
  77. class ConvNext1DModel(ModelMixin, ConfigMixin):
  78. r"""
  79. A ConvNext model that takes a noisy sample and a timestep and returns a sample shaped output.
  80. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
  81. for all models (such as downloading or saving).
  82. Parameters:
  83. in_channels (`int`, *optional*, defaults to 128):
  84. Number of channels in the input sample.
  85. out_channels (`int`, *optional*, defaults to 128):
  86. Number of channels in the output.
  87. intermediate_dim (`int`, *optional*, defaults to 512):
  88. Dimensionality of the intermediate blocks.
  89. mlp_dim (`int`, *optional*, defaults to 2048):
  90. Dimensionality of the MLP.
  91. num_layers (`int`, *optional*, defaults to 20):
  92. Number of layers in the model.
  93. dilation_cycle_length (`int`, *optional*, defaults to 4):
  94. Length of the dilation cycle.
  95. time_embedding_type (`str`, *optional*, defaults to `positional`):
  96. The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
  97. time_embedding_dim (`int`, *optional*, defaults to `None`):
  98. An optional override for the dimension of the projected time embedding.
  99. time_embedding_act_fn (`str`, *optional*, defaults to `None`):
  100. Optional activation function to use only once on the time embeddings before they are passed to the rest of
  101. the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
  102. """
  103. _supports_gradient_checkpointing = True
  104. @register_to_config
  105. def __init__(
  106. self,
  107. in_channels: int = 128,
  108. out_channels: int = 128,
  109. intermediate_dim: int = 512,
  110. mlp_dim: int = 2048,
  111. num_layers: int = 20,
  112. dilation_cycle_length: int = 4,
  113. time_embedding_type: str = "positional",
  114. ):
  115. super().__init__()
  116. if intermediate_dim % 2 != 0:
  117. raise ValueError("intermediate_dim must be divisible by 2.")
  118. # time
  119. if time_embedding_type == "fourier":
  120. self.time_proj = GaussianFourierProjection(
  121. intermediate_dim // 2,
  122. set_W_to_weight=False,
  123. log=False,
  124. flip_sin_to_cos=False,
  125. )
  126. timestep_input_dim = intermediate_dim
  127. elif time_embedding_type == "positional":
  128. self.time_proj = Timesteps(in_channels, False, 0)
  129. timestep_input_dim = in_channels
  130. else:
  131. raise ValueError(
  132. f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
  133. )
  134. self.time_mlp = TimestepEmbedding(
  135. timestep_input_dim,
  136. intermediate_dim,
  137. act_fn="silu",
  138. cond_proj_dim=None, # No conditional projection for now
  139. )
  140. # Project to intermediate dim
  141. self.in_proj = nn.Conv1d(in_channels, intermediate_dim, 1)
  142. self.out_proj = nn.Conv1d(intermediate_dim, out_channels, 1)
  143. # Initialize weights
  144. nn.init.normal_(self.out_proj.weight, mean=0, std=0.01)
  145. nn.init.zeros_(self.out_proj.bias)
  146. # Blocks
  147. self.blocks = nn.ModuleList(
  148. [
  149. ConvNeXtBlock(
  150. dim=intermediate_dim,
  151. intermediate_dim=mlp_dim,
  152. dilation=2 ** (i % dilation_cycle_length),
  153. )
  154. for i in range(num_layers)
  155. ]
  156. )
  157. self.gradient_checkpointing = False
  158. def _set_gradient_checkpointing(self, module, value: bool = False):
  159. self.gradient_checkpointing = value
  160. def forward(
  161. self,
  162. sample: torch.FloatTensor,
  163. timestep: Union[torch.Tensor, float, int],
  164. sample_mask: Optional[torch.Tensor] = None,
  165. condition: Optional[torch.Tensor] = None,
  166. ) -> torch.FloatTensor:
  167. r"""
  168. The [`ConvNext1DModel`] forward method.
  169. Args:
  170. sample (`torch.FloatTensor`):
  171. The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`.
  172. timestep (`torch.FloatTensor` or `float` or `int`):
  173. The number of timesteps to denoise an input.
  174. sample_mask (`torch.BoolTensor`, *optional*):
  175. A mask of the same shape as `sample` that indicates which elements are invalid.
  176. True means the element is invalid and should be masked out.
  177. return_dict (`bool`, *optional*, defaults to `True`):
  178. Whether or not to return a [`~models.unet_1d.ConvNext1DOutput`] instead of a plain tuple.
  179. Returns:
  180. [`~models.unet_1d.ConvNext1DOutput`] or `tuple`:
  181. If `return_dict` is True, an [`~models.unet_1d.ConvNext1DOutput`] is returned, otherwise a `tuple` is
  182. returned where the first element is the sample tensor.
  183. """
  184. # 1. time
  185. t_emb = self.time_proj(timestep)
  186. t_emb = self.time_mlp(t_emb)[..., None]
  187. # 2. pre-process
  188. if condition is not None:
  189. sample = torch.cat([sample, condition], dim=1)
  190. x = self.in_proj(sample)
  191. if sample_mask.ndim == 2:
  192. sample_mask = sample_mask[:, None, :]
  193. # 3. blocks
  194. for block in self.blocks:
  195. if self.training and self.is_gradient_checkpointing:
  196. x = torch.utils.checkpoint.checkpoint(block, x, t_emb, sample_mask)
  197. else:
  198. x = block(x, t_emb, sample_mask)
  199. # 4. post-process
  200. return self.out_proj(x)