wavenet.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. import math
  2. import torch
  3. import torch.nn.functional as F
  4. from torch import nn
  5. class Mish(nn.Module):
  6. def forward(self, x):
  7. return x * torch.tanh(F.softplus(x))
  8. class DiffusionEmbedding(nn.Module):
  9. """Diffusion Step Embedding"""
  10. def __init__(self, d_denoiser):
  11. super(DiffusionEmbedding, self).__init__()
  12. self.dim = d_denoiser
  13. def forward(self, x):
  14. device = x.device
  15. half_dim = self.dim // 2
  16. emb = math.log(10000) / (half_dim - 1)
  17. emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
  18. emb = x[:, None] * emb[None, :]
  19. emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
  20. return emb
  21. class LinearNorm(nn.Module):
  22. """LinearNorm Projection"""
  23. def __init__(self, in_features, out_features, bias=False):
  24. super(LinearNorm, self).__init__()
  25. self.linear = nn.Linear(in_features, out_features, bias)
  26. nn.init.xavier_uniform_(self.linear.weight)
  27. if bias:
  28. nn.init.constant_(self.linear.bias, 0.0)
  29. def forward(self, x):
  30. x = self.linear(x)
  31. return x
  32. class ConvNorm(nn.Module):
  33. """1D Convolution"""
  34. def __init__(
  35. self,
  36. in_channels,
  37. out_channels,
  38. kernel_size=1,
  39. stride=1,
  40. padding=None,
  41. dilation=1,
  42. bias=True,
  43. w_init_gain="linear",
  44. ):
  45. super(ConvNorm, self).__init__()
  46. if padding is None:
  47. assert kernel_size % 2 == 1
  48. padding = int(dilation * (kernel_size - 1) / 2)
  49. self.conv = nn.Conv1d(
  50. in_channels,
  51. out_channels,
  52. kernel_size=kernel_size,
  53. stride=stride,
  54. padding=padding,
  55. dilation=dilation,
  56. bias=bias,
  57. )
  58. nn.init.kaiming_normal_(self.conv.weight)
  59. def forward(self, signal):
  60. conv_signal = self.conv(signal)
  61. return conv_signal
  62. class ResidualBlock(nn.Module):
  63. """Residual Block"""
  64. def __init__(self, d_encoder, residual_channels, use_linear_bias=False, dilation=1):
  65. super(ResidualBlock, self).__init__()
  66. self.conv_layer = ConvNorm(
  67. residual_channels,
  68. 2 * residual_channels,
  69. kernel_size=3,
  70. stride=1,
  71. padding=dilation,
  72. dilation=dilation,
  73. )
  74. self.diffusion_projection = LinearNorm(
  75. residual_channels, residual_channels, use_linear_bias
  76. )
  77. self.conditioner_projection = ConvNorm(
  78. d_encoder, 2 * residual_channels, kernel_size=1
  79. )
  80. self.output_projection = ConvNorm(
  81. residual_channels, 2 * residual_channels, kernel_size=1
  82. )
  83. def forward(self, x, conditioner, diffusion_step):
  84. diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1)
  85. conditioner = self.conditioner_projection(conditioner)
  86. y = x + diffusion_step
  87. y = self.conv_layer(y) + conditioner
  88. gate, filter = torch.chunk(y, 2, dim=1)
  89. y = torch.sigmoid(gate) * torch.tanh(filter)
  90. y = self.output_projection(y)
  91. residual, skip = torch.chunk(y, 2, dim=1)
  92. return (x + residual) / math.sqrt(2.0), skip
  93. class WaveNet(nn.Module):
  94. """
  95. WaveNet
  96. https://www.deepmind.com/blog/wavenet-a-generative-model-for-raw-audio
  97. """
  98. def __init__(
  99. self,
  100. in_channels=128,
  101. out_channels=128,
  102. d_encoder=256,
  103. residual_channels=512,
  104. residual_layers=20,
  105. use_linear_bias=False,
  106. dilation_cycle=None,
  107. ):
  108. super(WaveNet, self).__init__()
  109. self.input_projection = ConvNorm(in_channels, residual_channels, kernel_size=1)
  110. self.diffusion_embedding = DiffusionEmbedding(residual_channels)
  111. self.mlp = nn.Sequential(
  112. LinearNorm(residual_channels, residual_channels * 4, use_linear_bias),
  113. Mish(),
  114. LinearNorm(residual_channels * 4, residual_channels, use_linear_bias),
  115. )
  116. self.residual_layers = nn.ModuleList(
  117. [
  118. ResidualBlock(
  119. d_encoder,
  120. residual_channels,
  121. use_linear_bias=use_linear_bias,
  122. dilation=2 ** (i % dilation_cycle) if dilation_cycle else 1,
  123. )
  124. for i in range(residual_layers)
  125. ]
  126. )
  127. self.skip_projection = ConvNorm(
  128. residual_channels, residual_channels, kernel_size=1
  129. )
  130. self.output_projection = ConvNorm(
  131. residual_channels, out_channels, kernel_size=1
  132. )
  133. nn.init.zeros_(self.output_projection.conv.weight)
  134. def forward(self, x, diffusion_step, x_masks, condition):
  135. x = self.input_projection(x) # x [B, residual_channel, T]
  136. x = F.relu(x)
  137. diffusion_step = self.diffusion_embedding(diffusion_step)
  138. diffusion_step = self.mlp(diffusion_step)
  139. skip = []
  140. for layer in self.residual_layers:
  141. x, skip_connection = layer(x * x_masks, condition * x_masks, diffusion_step)
  142. skip.append(skip_connection)
  143. x = torch.sum(torch.stack(skip), dim=0) / math.sqrt(len(self.residual_layers))
  144. x = self.skip_projection(x)
  145. x = F.relu(x)
  146. x = self.output_projection(x) # [B, 128, T]
  147. x = x * x_masks
  148. return x