| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.nn.utils import spectral_norm, weight_norm
- from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
- from fish_speech.models.vqgan.utils import get_padding
class DiscriminatorP(nn.Module):
    """Period discriminator that folds a 1-D signal into a 2-D grid.

    The input is right-padded (reflect mode) up to a multiple of ``period``,
    reshaped to ``(b, c, t // period, period)``, and passed through a stack of
    ``(kernel_size, 1)`` 2-D convolutions followed by a single-channel
    ``conv_post`` convolution.

    Args:
        period: width of the folded grid (last dimension after reshaping).
        kernel_size: kernel height of the five stacked convolutions.
        stride: stride along the folded axis for the first four convolutions;
            the fifth always uses stride 1.
        use_spectral_norm: a value equal to ``False`` selects ``weight_norm``;
            anything else selects ``spectral_norm``.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period

        # Keep the original loose equality test: only values == False pick weight norm.
        if use_spectral_norm == False:  # noqa: E712
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # (in_channels, out_channels, stride along the folded axis)
        layer_specs = (
            (1, 32, stride),
            (32, 128, stride),
            (128, 512, stride),
            (512, 1024, stride),
            (1024, 1024, 1),
        )
        # Layers are created and normalized one at a time, in list order, so the
        # ModuleList indices (and therefore state_dict keys) are convs.0 .. convs.4.
        self.convs = nn.ModuleList(
            norm_f(
                nn.Conv2d(
                    c_in,
                    c_out,
                    (kernel_size, 1),
                    (step, 1),
                    padding=(get_padding(kernel_size, 1), 0),
                )
            )
            for c_in, c_out, step in layer_specs
        )
        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Score ``x`` and collect intermediate activations.

        ``x`` must have exactly three dimensions (unpacked as ``b, c, t``).

        Returns:
            A tuple ``(scores, feature_maps)``. ``scores`` is the ``conv_post``
            output flattened from dim 1 onward. ``feature_maps`` holds each
            leaky-ReLU conv activation followed by the raw ``conv_post`` output.
        """
        batch, channels, length = x.shape

        remainder = length % self.period
        if remainder:
            # Reflect-pad on the right so the length divides evenly by the period.
            pad = self.period - remainder
            x = F.pad(x, (0, pad), "reflect")
            length += pad
        # 1-D -> 2-D: fold the last axis into rows of width `period`.
        x = x.view(batch, channels, length // self.period, self.period)

        feature_maps = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            feature_maps.append(x)

        x = self.conv_post(x)
        feature_maps.append(x)
        return torch.flatten(x, 1, -1), feature_maps
class DiscriminatorS(nn.Module):
    """Scale discriminator built from a stack of (mostly grouped) 1-D convolutions.

    Args:
        use_spectral_norm: a value equal to ``False`` selects ``weight_norm``;
            anything else selects ``spectral_norm``.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()

        # Keep the original loose equality test: only values == False pick weight norm.
        if use_spectral_norm == False:  # noqa: E712
            norm_f = weight_norm
        else:
            norm_f = spectral_norm

        # (in_channels, out_channels, kernel_size, stride, groups, padding)
        layer_specs = (
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        )
        # Built in table order so ModuleList indices / state_dict keys are convs.0 .. convs.6.
        self.convs = nn.ModuleList(
            norm_f(nn.Conv1d(c_in, c_out, k, s, groups=g, padding=p))
            for c_in, c_out, k, s, g, p in layer_specs
        )
        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Score ``x`` and collect intermediate activations.

        Returns:
            A tuple ``(scores, feature_maps)``. ``scores`` is the ``conv_post``
            output flattened from dim 1 onward. ``feature_maps`` holds each
            leaky-ReLU conv activation followed by the raw ``conv_post`` output.
        """
        feature_maps = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            feature_maps.append(x)

        logits = self.conv_post(x)
        feature_maps.append(logits)
        return torch.flatten(logits, 1, -1), feature_maps
class EnsembleDiscriminator(nn.Module):
    """Ensemble of one ``DiscriminatorS`` and five ``DiscriminatorP`` modules.

    ``self.discriminators[0]`` is a scale discriminator built with spectral
    norm. ``self.discriminators[1:]`` are period discriminators (weight norm)
    for periods 2, 3, 5, 7 and 11, in that order.

    Args:
        ckpt_path: optional checkpoint path. When given, weights are loaded
            with ``restore_from_ckpt`` right after construction.
    """

    def __init__(self, ckpt_path=None):
        super().__init__()
        periods = [2, 3, 5, 7, 11]
        discs = [DiscriminatorS(use_spectral_norm=True)]
        discs += [DiscriminatorP(p, use_spectral_norm=False) for p in periods]
        self.discriminators = nn.ModuleList(discs)
        if ckpt_path is not None:
            self.restore_from_ckpt(ckpt_path)

    def restore_from_ckpt(self, ckpt_path):
        """Load weights from a checkpoint holding ``"mpd"`` and ``"msd"`` state dicts.

        Keys of the form ``<name>.<i>.<rest>`` from ``"mpd"`` are renamed to
        ``<name>.<i + 1>.<rest>``, because index 0 of this ensemble is the
        scale discriminator. From ``"msd"``, only entries under
        ``discriminators.0`` are kept. Loading is strict, so a missing or
        unexpected key raises.

        Raises:
            KeyError: if the checkpoint lacks ``"mpd"`` or ``"msd"``.
            RuntimeError: from ``load_state_dict`` on a key or shape mismatch.
        """
        # NOTE(review): torch.load unpickles the file. Only point this at trusted
        # checkpoints (weights_only=True is an option on torch versions that support it).
        ckpt = torch.load(ckpt_path, map_location="cpu")
        mpd, msd = ckpt["mpd"], ckpt["msd"]

        all_keys = {}
        # Period-discriminator weights: shift the ModuleList index by one.
        for k, v in mpd.items():
            keys = k.split(".")
            keys[1] = str(int(keys[1]) + 1)
            all_keys[".".join(keys)] = v
        # Scale-discriminator weights: keep only sub-discriminator 0. Compare the
        # index component exactly instead of relying on a string prefix.
        for k, v in msd.items():
            if k.split(".")[:2] == ["discriminators", "0"]:
                all_keys[k] = v
        self.load_state_dict(all_keys, strict=True)

    def forward(self, y, y_hat):
        """Run every discriminator on ``y`` and on ``y_hat``.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: the per-discriminator
            outputs and feature-map lists for ``y`` and for ``y_hat``
            respectively, each ordered like ``self.discriminators``.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
if __name__ == "__main__":
    # Manual smoke check: build the ensemble and restore weights from a local checkpoint.
    _ckpt_path = "checkpoints/hifigan-v1-universal-22050/do_02500000"
    m = EnsembleDiscriminator(ckpt_path=_ckpt_path)
|