# modules.py — WaveNet backbone for the VQ-GAN models.
  1. import torch
  2. import torch.nn as nn
  3. from torch.nn.utils.parametrizations import weight_norm
  4. from torch.nn.utils.parametrize import remove_parametrizations
  5. from fish_speech.models.vqgan.utils import fused_add_tanh_sigmoid_multiply
  6. LRELU_SLOPE = 0.1
  7. # ! PosteriorEncoder
  8. # ! ResidualCouplingLayer
  9. class WaveNet(nn.Module):
  10. def __init__(
  11. self,
  12. hidden_channels,
  13. kernel_size,
  14. dilation_rate,
  15. n_layers,
  16. p_dropout=0,
  17. out_channels=None,
  18. in_channels=None,
  19. ):
  20. super(WaveNet, self).__init__()
  21. assert kernel_size % 2 == 1
  22. self.hidden_channels = hidden_channels
  23. self.kernel_size = (kernel_size,)
  24. self.n_layers = n_layers
  25. self.in_layers = nn.ModuleList()
  26. self.res_skip_layers = nn.ModuleList()
  27. self.drop = nn.Dropout(p_dropout)
  28. self.in_channels = in_channels
  29. if in_channels is not None:
  30. self.proj_in = nn.Conv1d(in_channels, hidden_channels, 1)
  31. for i in range(n_layers):
  32. dilation = dilation_rate**i
  33. padding = int((kernel_size * dilation - dilation) / 2)
  34. in_layer = nn.Conv1d(
  35. hidden_channels,
  36. 2 * hidden_channels,
  37. kernel_size,
  38. dilation=dilation,
  39. padding=padding,
  40. )
  41. in_layer = weight_norm(in_layer, name="weight")
  42. self.in_layers.append(in_layer)
  43. # last one is not necessary
  44. res_skip_channels = (
  45. 2 * hidden_channels if i < n_layers - 1 else hidden_channels
  46. )
  47. res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1)
  48. res_skip_layer = weight_norm(res_skip_layer, name="weight")
  49. self.res_skip_layers.append(res_skip_layer)
  50. self.out_channels = out_channels
  51. if out_channels is not None:
  52. self.out_layer = nn.Conv1d(hidden_channels, out_channels, 1)
  53. def forward(self, x, x_mask=None):
  54. n_channels_tensor = torch.IntTensor([self.hidden_channels])
  55. if self.in_channels is not None:
  56. x = self.proj_in(x)
  57. output = torch.zeros_like(x)
  58. for i in range(self.n_layers):
  59. x_in = self.in_layers[i](x)
  60. acts = fused_add_tanh_sigmoid_multiply(x_in, n_channels_tensor)
  61. acts = self.drop(acts)
  62. res_skip_acts = self.res_skip_layers[i](acts)
  63. if i < self.n_layers - 1:
  64. res_acts = res_skip_acts[:, : self.hidden_channels, :]
  65. x = x + res_acts
  66. if x_mask is not None:
  67. x = x * x_mask
  68. output = output + res_skip_acts[:, self.hidden_channels :, :]
  69. else:
  70. output = output + res_skip_acts
  71. if x_mask is not None:
  72. x = output * x_mask
  73. if self.out_channels is not None:
  74. x = self.out_layer(x)
  75. return x
  76. def remove_weight_norm(self):
  77. if self.gin_channels != 0:
  78. remove_parametrizations(self.cond_layer)
  79. for l in self.in_layers:
  80. remove_parametrizations(l)
  81. for l in self.res_skip_layers:
  82. remove_parametrizations(l)