modules.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import torch
  2. import torch.nn as nn
  3. from torch.nn.utils.parametrizations import weight_norm
  4. from torch.nn.utils.parametrize import remove_parametrizations
  5. from fish_speech.models.vqgan.utils import fused_add_tanh_sigmoid_multiply
  6. LRELU_SLOPE = 0.1
  7. # ! PosteriorEncoder
  8. # ! ResidualCouplingLayer
  9. class WN(nn.Module):
  10. def __init__(
  11. self,
  12. hidden_channels,
  13. kernel_size,
  14. dilation_rate,
  15. n_layers,
  16. gin_channels=0,
  17. p_dropout=0,
  18. out_channels=None,
  19. ):
  20. super(WN, self).__init__()
  21. assert kernel_size % 2 == 1
  22. self.hidden_channels = hidden_channels
  23. self.kernel_size = (kernel_size,)
  24. self.n_layers = n_layers
  25. self.gin_channels = gin_channels
  26. self.in_layers = nn.ModuleList()
  27. self.res_skip_layers = nn.ModuleList()
  28. self.drop = nn.Dropout(p_dropout)
  29. if gin_channels != 0:
  30. cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
  31. self.cond_layer = weight_norm(cond_layer, name="weight")
  32. for i in range(n_layers):
  33. dilation = dilation_rate**i
  34. padding = int((kernel_size * dilation - dilation) / 2)
  35. in_layer = nn.Conv1d(
  36. hidden_channels,
  37. 2 * hidden_channels,
  38. kernel_size,
  39. dilation=dilation,
  40. padding=padding,
  41. )
  42. in_layer = weight_norm(in_layer, name="weight")
  43. self.in_layers.append(in_layer)
  44. # last one is not necessary
  45. res_skip_channels = (
  46. 2 * hidden_channels if i < n_layers - 1 else hidden_channels
  47. )
  48. res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1)
  49. res_skip_layer = weight_norm(res_skip_layer, name="weight")
  50. self.res_skip_layers.append(res_skip_layer)
  51. self.out_channels = out_channels
  52. if out_channels is not None:
  53. self.out_layer = nn.Conv1d(hidden_channels, out_channels, 1)
  54. def forward(self, x, x_mask, g=None, **kwargs):
  55. output = torch.zeros_like(x)
  56. n_channels_tensor = torch.IntTensor([self.hidden_channels])
  57. if g is not None:
  58. g = self.cond_layer(g)
  59. for i in range(self.n_layers):
  60. x_in = self.in_layers[i](x)
  61. if g is not None:
  62. cond_offset = i * 2 * self.hidden_channels
  63. g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
  64. else:
  65. g_l = torch.zeros_like(x_in)
  66. acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
  67. acts = self.drop(acts)
  68. res_skip_acts = self.res_skip_layers[i](acts)
  69. if i < self.n_layers - 1:
  70. res_acts = res_skip_acts[:, : self.hidden_channels, :]
  71. x = (x + res_acts) * x_mask
  72. output = output + res_skip_acts[:, self.hidden_channels :, :]
  73. else:
  74. output = output + res_skip_acts
  75. x = output * x_mask
  76. if self.out_channels is not None:
  77. x = self.out_layer(x)
  78. return x
  79. def remove_weight_norm(self):
  80. if self.gin_channels != 0:
  81. remove_parametrizations(self.cond_layer)
  82. for l in self.in_layers:
  83. remove_parametrizations(l)
  84. for l in self.res_skip_layers:
  85. remove_parametrizations(l)
  86. # ! StochasticDurationPredictor
  87. # ! ResidualCouplingBlock
  88. # TODO convert to class method
  89. class Flip(nn.Module):
  90. def forward(self, x, *args, reverse=False, **kwargs):
  91. x = torch.flip(x, [1])
  92. if not reverse:
  93. logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
  94. return x, logdet
  95. else:
  96. return x