# modules.py
  1. import torch
  2. import torch.nn as nn
  3. from torch.nn.utils.parametrizations import weight_norm
  4. from torch.nn.utils.parametrize import remove_parametrizations
  5. from fish_speech.models.vqgan.utils import fused_add_tanh_sigmoid_multiply
  6. LRELU_SLOPE = 0.1
  7. # ! PosteriorEncoder
  8. # ! ResidualCouplingLayer
  9. class WN(nn.Module):
  10. def __init__(
  11. self,
  12. hidden_channels,
  13. kernel_size,
  14. dilation_rate,
  15. n_layers,
  16. gin_channels=0,
  17. p_dropout=0,
  18. ):
  19. super(WN, self).__init__()
  20. assert kernel_size % 2 == 1
  21. self.hidden_channels = hidden_channels
  22. self.kernel_size = (kernel_size,)
  23. self.n_layers = n_layers
  24. self.gin_channels = gin_channels
  25. self.in_layers = nn.ModuleList()
  26. self.res_skip_layers = nn.ModuleList()
  27. self.drop = nn.Dropout(p_dropout)
  28. if gin_channels != 0:
  29. cond_layer = nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
  30. self.cond_layer = weight_norm(cond_layer, name="weight")
  31. for i in range(n_layers):
  32. dilation = dilation_rate**i
  33. padding = int((kernel_size * dilation - dilation) / 2)
  34. in_layer = nn.Conv1d(
  35. hidden_channels,
  36. 2 * hidden_channels,
  37. kernel_size,
  38. dilation=dilation,
  39. padding=padding,
  40. )
  41. in_layer = weight_norm(in_layer, name="weight")
  42. self.in_layers.append(in_layer)
  43. # last one is not necessary
  44. res_skip_channels = (
  45. 2 * hidden_channels if i < n_layers - 1 else hidden_channels
  46. )
  47. res_skip_layer = nn.Conv1d(hidden_channels, res_skip_channels, 1)
  48. res_skip_layer = weight_norm(res_skip_layer, name="weight")
  49. self.res_skip_layers.append(res_skip_layer)
  50. def forward(self, x, x_mask, g=None, **kwargs):
  51. output = torch.zeros_like(x)
  52. n_channels_tensor = torch.IntTensor([self.hidden_channels])
  53. if g is not None:
  54. g = self.cond_layer(g)
  55. for i in range(self.n_layers):
  56. x_in = self.in_layers[i](x)
  57. if g is not None:
  58. cond_offset = i * 2 * self.hidden_channels
  59. g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
  60. else:
  61. g_l = torch.zeros_like(x_in)
  62. acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
  63. acts = self.drop(acts)
  64. res_skip_acts = self.res_skip_layers[i](acts)
  65. if i < self.n_layers - 1:
  66. res_acts = res_skip_acts[:, : self.hidden_channels, :]
  67. x = (x + res_acts) * x_mask
  68. output = output + res_skip_acts[:, self.hidden_channels :, :]
  69. else:
  70. output = output + res_skip_acts
  71. return output * x_mask
  72. def remove_weight_norm(self):
  73. if self.gin_channels != 0:
  74. remove_parametrizations(self.cond_layer)
  75. for l in self.in_layers:
  76. remove_parametrizations(l)
  77. for l in self.res_skip_layers:
  78. remove_parametrizations(l)
  79. # ! StochasticDurationPredictor
  80. # ! ResidualCouplingBlock
  81. # TODO convert to class method
  82. class Flip(nn.Module):
  83. def forward(self, x, *args, reverse=False, **kwargs):
  84. x = torch.flip(x, [1])
  85. if not reverse:
  86. logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
  87. return x, logdet
  88. else:
  89. return x