# discriminator.py — HiFi-GAN style ensemble discriminator (period + scale).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm, weight_norm

from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
from fish_speech.models.vqgan.utils import get_padding
  7. class DiscriminatorP(nn.Module):
  8. def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
  9. super(DiscriminatorP, self).__init__()
  10. self.period = period
  11. norm_f = weight_norm if use_spectral_norm == False else spectral_norm
  12. self.convs = nn.ModuleList(
  13. [
  14. norm_f(
  15. nn.Conv2d(
  16. 1,
  17. 32,
  18. (kernel_size, 1),
  19. (stride, 1),
  20. padding=(get_padding(kernel_size, 1), 0),
  21. )
  22. ),
  23. norm_f(
  24. nn.Conv2d(
  25. 32,
  26. 128,
  27. (kernel_size, 1),
  28. (stride, 1),
  29. padding=(get_padding(kernel_size, 1), 0),
  30. )
  31. ),
  32. norm_f(
  33. nn.Conv2d(
  34. 128,
  35. 512,
  36. (kernel_size, 1),
  37. (stride, 1),
  38. padding=(get_padding(kernel_size, 1), 0),
  39. )
  40. ),
  41. norm_f(
  42. nn.Conv2d(
  43. 512,
  44. 1024,
  45. (kernel_size, 1),
  46. (stride, 1),
  47. padding=(get_padding(kernel_size, 1), 0),
  48. )
  49. ),
  50. norm_f(
  51. nn.Conv2d(
  52. 1024,
  53. 1024,
  54. (kernel_size, 1),
  55. 1,
  56. padding=(get_padding(kernel_size, 1), 0),
  57. )
  58. ),
  59. ]
  60. )
  61. self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
  62. def forward(self, x):
  63. fmap = []
  64. # 1d to 2d
  65. b, c, t = x.shape
  66. if t % self.period != 0: # pad first
  67. n_pad = self.period - (t % self.period)
  68. x = F.pad(x, (0, n_pad), "reflect")
  69. t = t + n_pad
  70. x = x.view(b, c, t // self.period, self.period)
  71. for l in self.convs:
  72. x = l(x)
  73. x = F.leaky_relu(x, LRELU_SLOPE)
  74. fmap.append(x)
  75. x = self.conv_post(x)
  76. fmap.append(x)
  77. x = torch.flatten(x, 1, -1)
  78. return x, fmap
  79. class DiscriminatorS(nn.Module):
  80. def __init__(self, use_spectral_norm=False):
  81. super(DiscriminatorS, self).__init__()
  82. norm_f = weight_norm if use_spectral_norm == False else spectral_norm
  83. self.convs = nn.ModuleList(
  84. [
  85. norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
  86. norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
  87. norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
  88. norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
  89. norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
  90. norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
  91. norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
  92. ]
  93. )
  94. self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
  95. def forward(self, x):
  96. fmap = []
  97. for l in self.convs:
  98. x = l(x)
  99. x = F.leaky_relu(x, LRELU_SLOPE)
  100. fmap.append(x)
  101. x = self.conv_post(x)
  102. fmap.append(x)
  103. x = torch.flatten(x, 1, -1)
  104. return x, fmap
  105. class EnsembleDiscriminator(nn.Module):
  106. def __init__(self, ckpt_path=None):
  107. super(EnsembleDiscriminator, self).__init__()
  108. periods = [2, 3, 5, 7, 11] # [1, 2, 3, 5, 7, 11]
  109. discs = [DiscriminatorS(use_spectral_norm=True)]
  110. discs = discs + [DiscriminatorP(i, use_spectral_norm=False) for i in periods]
  111. self.discriminators = nn.ModuleList(discs)
  112. if ckpt_path is not None:
  113. self.restore_from_ckpt(ckpt_path)
  114. def restore_from_ckpt(self, ckpt_path):
  115. ckpt = torch.load(ckpt_path, map_location="cpu")
  116. mpd, msd = ckpt["mpd"], ckpt["msd"]
  117. all_keys = {}
  118. for k, v in mpd.items():
  119. keys = k.split(".")
  120. keys[1] = str(int(keys[1]) + 1)
  121. all_keys[".".join(keys)] = v
  122. for k, v in msd.items():
  123. if not k.startswith("discriminators.0"):
  124. continue
  125. all_keys[k] = v
  126. self.load_state_dict(all_keys, strict=True)
  127. def forward(self, y, y_hat):
  128. y_d_rs = []
  129. y_d_gs = []
  130. fmap_rs = []
  131. fmap_gs = []
  132. for i, d in enumerate(self.discriminators):
  133. y_d_r, fmap_r = d(y)
  134. y_d_g, fmap_g = d(y_hat)
  135. y_d_rs.append(y_d_r)
  136. y_d_gs.append(y_d_g)
  137. fmap_rs.append(fmap_r)
  138. fmap_gs.append(fmap_g)
  139. return y_d_rs, y_d_gs, fmap_rs, fmap_gs
  140. if __name__ == "__main__":
  141. m = EnsembleDiscriminator(
  142. ckpt_path="checkpoints/hifigan-v1-universal-22050/do_02500000"
  143. )