| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch.nn.utils.parametrizations import weight_norm
- from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
- from fish_speech.models.vqgan.modules.modules import LRELU_SLOPE
- from fish_speech.models.vqgan.utils import get_padding, init_weights
class Generator(nn.Module):
    """HiFi-GAN style generator mapping a feature sequence to a 1-channel signal.

    Pipeline: ``conv_pre`` -> optional global conditioning add ->
    ``len(upsample_rates)`` stages of (leaky ReLU -> transposed-conv upsample
    -> mean of ``len(resblock_kernel_sizes)`` parallel residual blocks) ->
    leaky ReLU -> ``conv_post`` (1 output channel) -> ``tanh``.

    Args:
        initial_channel: Channel count of the input ``x`` to ``forward``.
        resblock: ``"1"`` selects ``ResBlock1``; any other value selects
            ``ResBlock2``.
        resblock_kernel_sizes: Kernel size of each parallel residual block.
        resblock_dilation_sizes: Dilation sequence for each parallel residual
            block (zipped with ``resblock_kernel_sizes``).
        upsample_rates: Stride of each transposed-conv upsampling stage.
        upsample_initial_channel: Channels after ``conv_pre``; each upsampling
            stage halves the channel count.
        upsample_kernel_sizes: Kernel size of each upsampling stage (zipped
            with ``upsample_rates``).
        gin_channels: If non-zero, creates ``self.cond``, a linear projection
            of the conditioning tensor ``g`` to ``upsample_initial_channel``.
        ckpt_path: Optional checkpoint path; its ``"generator"`` entry is
            loaded with ``strict=True``.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
        ckpt_path=None,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = weight_norm(
            nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock1 if resblock == "1" else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    nn.ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Stored flat: stage i owns resblocks[i * num_kernels : (i + 1) * num_kernels].
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(resblock(ch, k, d))

        # Channel count after the last upsampling stage. Computed explicitly
        # so it is defined even when there are no upsampling stages.
        ch = upsample_initial_channel // (2 ** len(self.ups))
        self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Linear(gin_channels, upsample_initial_channel)

        if ckpt_path is not None:
            # map_location="cpu" lets GPU-saved checkpoints load on CPU-only
            # hosts; load_state_dict then copies tensors into this module.
            state = torch.load(ckpt_path, map_location="cpu")
            self.load_state_dict(state["generator"], strict=True)

    def forward(self, x, g=None):
        """Run the generator.

        Args:
            x: Input for ``conv_pre`` with ``initial_channel`` channels.
            g: Optional conditioning tensor. ``g.mT`` is fed to an
                ``nn.Linear(gin_channels, ...)``, so the second-to-last dim of
                ``g`` must be ``gin_channels``. Requires ``gin_channels != 0``.

        Returns:
            ``tanh`` of ``conv_post``'s single-channel output.
        """
        x = self.conv_pre(x)
        if g is not None:
            # Linear acts on the last dim, so transpose channels there and back.
            x = x + self.cond(g.mT).mT

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's parallel residual blocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        # Uses F.leaky_relu's default negative slope, not LRELU_SLOPE.
        # NOTE(review): presumably intentional to match pretrained checkpoints;
        # confirm before changing.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        """Strip the weight-norm parametrization from every weight-normed layer.

        The module-level ``remove_weight_norm`` is
        ``torch.nn.utils.parametrize.remove_parametrizations``, which needs the
        parametrized tensor's name, so ``"weight"`` is passed explicitly.
        With its default ``leave_parametrized=True``, the current effective
        weight is kept as a plain parameter.
        """
        print("Removing weight norm...")
        remove_weight_norm(self.conv_pre, "weight")
        for layer in self.ups:
            remove_weight_norm(layer, "weight")
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_post, "weight")
class ResBlock1(nn.Module):
    """Residual block with a dilated conv and an undilated conv per dilation.

    For each ``d`` in ``dilation`` it computes
    ``x = x + conv2(leaky_relu(conv1_d(leaky_relu(x))))``, with optional
    masking before each conv and on the output.

    Args:
        channels: Input and output channel count of every conv.
        kernel_size: Kernel size of every conv.
        dilation: One dilation per residual step (default ``(1, 3, 5)``).
            Padding comes from ``get_padding``. The residual add requires it
            to keep the sequence length unchanged.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        # One dilated conv per entry of `dilation`.
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)
        # Matching undilated conv for each dilated conv above.
        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply the residual steps to ``x``.

        Args:
            x: Input to ``nn.Conv1d`` with ``channels`` channels.
            x_mask: Optional tensor multiplied elementwise into the activations.
                NOTE(review): presumably a (B, 1, T) padding mask; confirm.
        """
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip the ``weight`` parametrization from every conv in the block.

        ``remove_weight_norm`` here is ``remove_parametrizations``, which
        requires the tensor name argument.
        """
        for layer in self.convs1:
            remove_weight_norm(layer, "weight")
        for layer in self.convs2:
            remove_weight_norm(layer, "weight")
class ResBlock2(nn.Module):
    """Lightweight residual block: one dilated conv per residual step.

    For each ``d`` in ``dilation`` it computes
    ``x = x + conv_d(leaky_relu(x))``, with optional masking before each conv
    and on the output.

    Args:
        channels: Input and output channel count of every conv.
        kernel_size: Kernel size of every conv.
        dilation: One dilation per residual step (default ``(1, 3)``).
            Padding comes from ``get_padding``. The residual add requires it
            to keep the sequence length unchanged.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply the residual steps to ``x``.

        Args:
            x: Input to ``nn.Conv1d`` with ``channels`` channels.
            x_mask: Optional tensor multiplied elementwise into the activations.
                NOTE(review): presumably a (B, 1, T) padding mask; confirm.
        """
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip the ``weight`` parametrization from every conv in the block.

        ``remove_weight_norm`` here is ``remove_parametrizations``, which
        requires the tensor name argument.
        """
        for layer in self.convs:
            remove_weight_norm(layer, "weight")
if __name__ == "__main__":
    # Demo: log-mel spectrogram of a sample clip -> pretrained generator -> test.wav
    import librosa
    import soundfile as sf

    from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram

    gen = Generator(
        80,
        "1",
        [3, 7, 11],
        [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        [8, 8, 2, 2],
        512,
        [16, 16, 4, 4],
        ckpt_path="checkpoints/hifigan-v1-universal-22050/g_02500000",
    )
    gen.eval()

    # Kept under a distinct name so the transform is not shadowed by its output.
    mel_transform = LogMelSpectrogram(
        sample_rate=22050,
        n_fft=1024,
        win_length=1024,
        hop_length=256,
        n_mels=80,
        f_min=0.0,
        f_max=8000.0,
    )

    audio = librosa.load("data/StarRail/Chinese/符玄/archive_fuxuan_9.wav", sr=22050)[0]
    audio = torch.from_numpy(audio).unsqueeze(0)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        mel = mel_transform(audio)
        print(mel.shape)
        wav = gen(mel)
    print(wav.shape)

    sf.write("test.wav", wav.detach().squeeze().numpy(), 22050)
|