# wavenet.py
import math
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


class Mish(nn.Module):
    def forward(self, x):
        return x * torch.tanh(F.softplus(x))


class DiffusionEmbedding(nn.Module):
    """Diffusion Step Embedding"""

    def __init__(self, d_denoiser):
        super(DiffusionEmbedding, self).__init__()
        self.dim = d_denoiser

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
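

# Illustrative usage sketch (not part of the original file): a minimal shape check
# for DiffusionEmbedding. The batch size, embedding width, and the assumption that
# diffusion steps arrive as a 1-D tensor of step indices are mine, inferred from
# the indexing `x[:, None]` above.
def _example_diffusion_embedding():
    embedding = DiffusionEmbedding(d_denoiser=128)
    steps = torch.arange(4)  # four diffusion step indices, shape (4,)
    out = embedding(steps)   # sin/cos halves concatenated -> shape (4, 128)
    assert out.shape == (4, 128)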
class LinearNorm(nn.Module):
    """LinearNorm Projection"""

    def __init__(self, in_features, out_features, bias=False):
        super(LinearNorm, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        nn.init.xavier_uniform_(self.linear.weight)
        if bias:
            nn.init.constant_(self.linear.bias, 0.0)

    def forward(self, x):
        x = self.linear(x)
        return x


class ConvNorm(nn.Module):
    """1D Convolution"""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
    ):
        super(ConvNorm, self).__init__()

        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal
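

# Illustrative usage sketch (not part of the original file): with the default
# padding of dilation * (kernel_size - 1) / 2 and stride 1, ConvNorm preserves the
# sequence length. The channel counts and lengths below are arbitrary assumptions.
def _example_convnorm_same_length():
    conv = ConvNorm(in_channels=80, out_channels=256, kernel_size=3, dilation=2)
    x = torch.randn(2, 80, 100)  # (batch, channels, time)
    y = conv(x)
    assert y.shape == (2, 256, 100)  # time dimension is unchanged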
class ResidualBlock(nn.Module):
    """Residual Block"""

    def __init__(
        self,
        residual_channels,
        use_linear_bias=False,
        dilation=1,
        condition_channels=None,
    ):
        super(ResidualBlock, self).__init__()
        self.conv_layer = ConvNorm(
            residual_channels,
            2 * residual_channels,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        )

        # The diffusion-step and condition projections are only built when a
        # conditioning signal is expected.
        if condition_channels is not None:
            self.diffusion_projection = LinearNorm(
                residual_channels, residual_channels, use_linear_bias
            )
            self.condition_projection = ConvNorm(
                condition_channels, 2 * residual_channels, kernel_size=1
            )

        self.output_projection = ConvNorm(
            residual_channels, 2 * residual_channels, kernel_size=1
        )

    def forward(self, x, condition=None, diffusion_step=None):
        y = x

        if diffusion_step is not None:
            diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
            y = y + diffusion_step

        y = self.conv_layer(y)

        if condition is not None:
            condition = self.condition_projection(condition)
            y = y + condition

        # Gated activation unit: sigmoid(gate) * tanh(filter)
        gate, filter = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filter)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)

        return (x + residual) / math.sqrt(2.0), skip
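

# Illustrative usage sketch (not part of the original file): one conditioned,
# diffusion-aware residual block. The channel widths, batch size, and sequence
# length are assumptions; the diffusion step embedding must match
# residual_channels, and the condition must match condition_channels.
def _example_residual_block():
    block = ResidualBlock(residual_channels=64, dilation=2, condition_channels=80)
    x = torch.randn(2, 64, 100)          # (batch, residual_channels, time)
    condition = torch.randn(2, 80, 100)  # (batch, condition_channels, time)
    diffusion_step = torch.randn(2, 64)  # (batch, residual_channels)
    out, skip = block(x, condition, diffusion_step)
    assert out.shape == skip.shape == (2, 64, 100)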
class WaveNet(nn.Module):
    def __init__(
        self,
        input_channels: Optional[int] = None,
        output_channels: Optional[int] = None,
        residual_channels: int = 512,
        residual_layers: int = 20,
        dilation_cycle: Optional[int] = 4,
        is_diffusion: bool = False,
        condition_channels: Optional[int] = None,
    ):
        super().__init__()

        # Input projection
        self.input_projection = None

        if input_channels is not None and input_channels != residual_channels:
            self.input_projection = ConvNorm(
                input_channels, residual_channels, kernel_size=1
            )

        if input_channels is None:
            input_channels = residual_channels

        self.input_channels = input_channels

        # Residual layers with a repeating dilation cycle (1, 2, 4, ...)
        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    residual_channels=residual_channels,
                    use_linear_bias=False,
                    dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
                    condition_channels=condition_channels,
                )
                for i in range(residual_layers)
            ]
        )

        # Skip projection
        self.skip_projection = ConvNorm(
            residual_channels, residual_channels, kernel_size=1
        )

        # Output projection
        self.output_projection = None

        if output_channels is not None and output_channels != residual_channels:
            self.output_projection = ConvNorm(
                residual_channels, output_channels, kernel_size=1
            )

        # Diffusion-step embedding and its MLP are only created when the network
        # is used as a diffusion denoiser.
        if is_diffusion:
            self.diffusion_embedding = DiffusionEmbedding(residual_channels)
            self.mlp = nn.Sequential(
                LinearNorm(residual_channels, residual_channels * 4, False),
                Mish(),
                LinearNorm(residual_channels * 4, residual_channels, False),
            )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, t=None, condition=None):
        if self.input_projection is not None:
            x = self.input_projection(x)
            x = F.silu(x)

        if t is not None:
            t = self.diffusion_embedding(t)
            t = self.mlp(t)

        skip = []
        for layer in self.residual_layers:
            x, skip_connection = layer(x, condition, t)
            skip.append(skip_connection)

        # Aggregate skip connections, scaled by sqrt of the layer count
        x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
        x = self.skip_projection(x)

        if self.output_projection is not None:
            x = F.silu(x)
            x = self.output_projection(x)

        return x
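

# Illustrative usage sketch (not part of the original file): running WaveNet as a
# conditioned diffusion denoiser. All hyperparameters below (80 mel channels,
# 256 residual channels, 4 layers, 256 condition channels) are assumptions chosen
# only to demonstrate the expected tensor shapes.
if __name__ == "__main__":
    model = WaveNet(
        input_channels=80,
        output_channels=80,
        residual_channels=256,
        residual_layers=4,
        dilation_cycle=4,
        is_diffusion=True,
        condition_channels=256,
    )

    x = torch.randn(2, 80, 100)           # noisy input, (batch, channels, time)
    t = torch.randint(0, 1000, (2,))      # diffusion step per batch element
    condition = torch.randn(2, 256, 100)  # conditioning features

    y = model(x, t, condition)
    print(y.shape)  # expected: torch.Size([2, 80, 100])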