# wavenet.py
  1. import math
  2. from typing import Optional, Union
  3. import torch
  4. import torch.nn.functional as F
  5. from torch import nn
  6. class Mish(nn.Module):
  7. def forward(self, x):
  8. return x * torch.tanh(F.softplus(x))
  9. class DiffusionEmbedding(nn.Module):
  10. """Diffusion Step Embedding"""
  11. def __init__(self, d_denoiser):
  12. super(DiffusionEmbedding, self).__init__()
  13. self.dim = d_denoiser
  14. def forward(self, x):
  15. device = x.device
  16. half_dim = self.dim // 2
  17. emb = math.log(10000) / (half_dim - 1)
  18. emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
  19. emb = x[:, None] * emb[None, :]
  20. emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
  21. return emb
  22. class LinearNorm(nn.Module):
  23. """LinearNorm Projection"""
  24. def __init__(self, in_features, out_features, bias=False):
  25. super(LinearNorm, self).__init__()
  26. self.linear = nn.Linear(in_features, out_features, bias)
  27. nn.init.xavier_uniform_(self.linear.weight)
  28. if bias:
  29. nn.init.constant_(self.linear.bias, 0.0)
  30. def forward(self, x):
  31. x = self.linear(x)
  32. return x
  33. class ConvNorm(nn.Module):
  34. """1D Convolution"""
  35. def __init__(
  36. self,
  37. in_channels,
  38. out_channels,
  39. kernel_size=1,
  40. stride=1,
  41. padding=None,
  42. dilation=1,
  43. bias=True,
  44. w_init_gain="linear",
  45. ):
  46. super(ConvNorm, self).__init__()
  47. if padding is None:
  48. assert kernel_size % 2 == 1
  49. padding = int(dilation * (kernel_size - 1) / 2)
  50. self.conv = nn.Conv1d(
  51. in_channels,
  52. out_channels,
  53. kernel_size=kernel_size,
  54. stride=stride,
  55. padding=padding,
  56. dilation=dilation,
  57. bias=bias,
  58. )
  59. nn.init.kaiming_normal_(self.conv.weight)
  60. def forward(self, signal):
  61. conv_signal = self.conv(signal)
  62. return conv_signal
  63. class ResidualBlock(nn.Module):
  64. """Residual Block"""
  65. def __init__(self, d_encoder, residual_channels, use_linear_bias=False, dilation=1):
  66. super(ResidualBlock, self).__init__()
  67. self.conv_layer = ConvNorm(
  68. residual_channels,
  69. 2 * residual_channels,
  70. kernel_size=3,
  71. stride=1,
  72. padding=dilation,
  73. dilation=dilation,
  74. )
  75. self.diffusion_projection = LinearNorm(
  76. residual_channels, residual_channels, use_linear_bias
  77. )
  78. self.condition_projection = ConvNorm(
  79. d_encoder, 2 * residual_channels, kernel_size=1
  80. )
  81. self.output_projection = ConvNorm(
  82. residual_channels, 2 * residual_channels, kernel_size=1
  83. )
  84. def forward(self, x, conditioner, diffusion_step):
  85. diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
  86. conditioner = self.condition_projection(conditioner)
  87. y = x + diffusion_step
  88. y = self.conv_layer(y) + conditioner
  89. gate, filter = torch.chunk(y, 2, dim=1)
  90. y = torch.sigmoid(gate) * torch.tanh(filter)
  91. y = self.output_projection(y)
  92. residual, skip = torch.chunk(y, 2, dim=1)
  93. return (x + residual) / math.sqrt(2.0), skip
  94. class SpectrogramUpsampler(nn.Module):
  95. def __init__(self, hop_size):
  96. super().__init__()
  97. if hop_size == 256:
  98. self.conv1 = nn.ConvTranspose2d(
  99. 1, 1, [3, 32], stride=[1, 16], padding=[1, 8]
  100. )
  101. elif hop_size == 512:
  102. self.conv1 = nn.ConvTranspose2d(
  103. 1, 1, [3, 64], stride=[1, 32], padding=[1, 16]
  104. )
  105. else:
  106. raise ValueError(f"Unsupported hop_size: {hop_size}")
  107. self.conv2 = nn.ConvTranspose2d(1, 1, [3, 32], stride=[1, 16], padding=[1, 8])
  108. def forward(self, x):
  109. x = torch.unsqueeze(x, 1)
  110. x = self.conv1(x)
  111. x = F.leaky_relu(x, 0.4)
  112. x = self.conv2(x)
  113. x = F.leaky_relu(x, 0.4)
  114. x = torch.squeeze(x, 1)
  115. return x
  116. class WaveNet(nn.Module):
  117. """
  118. WaveNet
  119. https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio
  120. """
  121. def __init__(
  122. self,
  123. mel_channels=128,
  124. d_encoder=256,
  125. residual_channels=512,
  126. residual_layers=20,
  127. use_linear_bias=False,
  128. dilation_cycle=None,
  129. ):
  130. super(WaveNet, self).__init__()
  131. self.input_projection = ConvNorm(mel_channels, residual_channels, kernel_size=1)
  132. self.diffusion_embedding = DiffusionEmbedding(residual_channels)
  133. self.mlp = nn.Sequential(
  134. LinearNorm(residual_channels, residual_channels * 4, use_linear_bias),
  135. Mish(),
  136. LinearNorm(residual_channels * 4, residual_channels, use_linear_bias),
  137. )
  138. self.residual_layers = nn.ModuleList(
  139. [
  140. ResidualBlock(
  141. d_encoder,
  142. residual_channels,
  143. use_linear_bias=use_linear_bias,
  144. dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
  145. )
  146. for i in range(residual_layers)
  147. ]
  148. )
  149. self.skip_projection = ConvNorm(
  150. residual_channels, residual_channels, kernel_size=1
  151. )
  152. self.output_projection = ConvNorm(
  153. residual_channels, mel_channels, kernel_size=1
  154. )
  155. nn.init.zeros_(self.output_projection.conv.weight)
  156. def forward(
  157. self,
  158. sample: torch.FloatTensor,
  159. timestep: Union[torch.Tensor, float, int],
  160. sample_mask: Optional[torch.Tensor] = None,
  161. condition: Optional[torch.Tensor] = None,
  162. ):
  163. x = self.input_projection(sample) # x [B, residual_channel, T]
  164. x = F.relu(x)
  165. diffusion_step = self.diffusion_embedding(timestep)
  166. diffusion_step = self.mlp(diffusion_step)
  167. if sample_mask is not None:
  168. if sample_mask.ndim == 2:
  169. sample_mask = sample_mask[:, None, :]
  170. x = x * sample_mask
  171. skip = []
  172. for layer in self.residual_layers:
  173. x, skip_connection = layer(x, condition, diffusion_step)
  174. skip.append(skip_connection)
  175. x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
  176. x = self.skip_projection(x)
  177. x = F.relu(x)
  178. x = self.output_projection(x) # [B, 128, T]
  179. if sample_mask is not None:
  180. x = x * sample_mask
  181. return x