Explorar el Código

Remove normalize flow from VITS

Lengyue hace 2 años
padre
commit
4ff12c2e7d

+ 37 - 11
fish_speech/configs/hubert_vq.yaml

@@ -11,12 +11,15 @@ trainer:
   strategy:
     _target_: lightning.pytorch.strategies.DDPStrategy
     static_graph: true
-  precision: 16-mixed
+  precision: 32
   max_steps: 1_000_000
+  val_check_interval: 1000
 
 sample_rate: 32000
 hop_length: 640
 num_mels: 128
+n_fft: 2048
+win_length: 2048
 
 # Dataset Configuration
 train_dataset:
@@ -31,7 +34,7 @@ val_dataset:
   filelist: data/vq_val_filelist.txt
   sample_rate: ${sample_rate}
   hop_length: ${hop_length}
-  slice_frames: null
+  slice_frames: 256
 
 data:
   _target_: fish_speech.datasets.vqgan.VQGANDataModule
@@ -48,22 +51,45 @@ model:
   hop_length: ${hop_length}
   segment_size: 20480
 
-  encoder:
-    _target_: fish_speech.models.vqgan.modules.VQEncoder
+  semantic_encoder:
+    _target_: fish_speech.models.vqgan.modules.SemanticEncoder
     in_channels: 1024
-    channels: 192
-    num_mels: ${num_mels}
+    hidden_channels: 384
+    out_channels: 192
     num_heads: 2
-    num_feature_layers: 2
-    num_speaker_layers: 4
-    num_mixin_layers: 4
+    num_layers: 8
     input_downsample: true
     code_book_size: 2048
     freeze_vq: false
+    gin_channels: ${model.speaker_encoder.out_channels}
+
+  posterior_encoder:
+    _target_: fish_speech.models.vqgan.modules.PosteriorEncoder
+    in_channels: "${eval: '${n_fft} // 2 + 1'}"
+    hidden_channels: 192
+    out_channels: 192
+    gin_channels: ${model.speaker_encoder.out_channels}
+  
+  speaker_encoder:
+    _target_: fish_speech.models.vqgan.modules.SpeakerEncoder
+    in_channels: ${num_mels}
+    hidden_channels: 192
+    out_channels: 512
+
+  # flow:
+  #   _target_: fish_speech.models.vqgan.modules.ResidualCouplingBlock
+  #   channels: 192
+  #   hidden_channels: 192
+  #   kernel_size: 5
+  #   dilation_rate: 1
+  #   n_layers: 4
+  #   n_flows: 4
+  #   gin_channels: ${model.speaker_encoder.out_channels}
 
   generator:
     _target_: fish_speech.models.vqgan.modules.Generator
     initial_channel: 192
+    gin_channels: ${model.speaker_encoder.out_channels}
     resblock: "1"
     resblock_kernel_sizes: [3, 7, 11]
     resblock_dilation_sizes: 
@@ -80,9 +106,9 @@ model:
   mel_transform:
     _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
     sample_rate: ${sample_rate}
-    n_fft: 2048
+    n_fft: ${n_fft}
     hop_length: ${hop_length}
-    win_length: 2048
+    win_length: ${win_length}
     n_mels: ${num_mels}
 
   optimizer:

+ 130 - 58
fish_speech/models/vqgan/lit_module.py

@@ -9,7 +9,13 @@ from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from matplotlib import pyplot as plt
 from torch import nn
 
-from fish_speech.models.vqgan.modules import EnsembleDiscriminator, Generator, VQEncoder
+from fish_speech.models.vqgan.modules import (
+    EnsembleDiscriminator,
+    Generator,
+    PosteriorEncoder,
+    SemanticEncoder,
+    SpeakerEncoder,
+)
 from fish_speech.models.vqgan.utils import plot_mel, sequence_mask
 
 
@@ -18,7 +24,10 @@ class VQGAN(L.LightningModule):
         self,
         optimizer: Callable,
         lr_scheduler: Callable,
-        encoder: VQEncoder,
+        semantic_encoder: SemanticEncoder,
+        posterior_encoder: PosteriorEncoder,
+        speaker_encoder: SpeakerEncoder,
+        # flow: nn.Module,
         generator: Generator,
         discriminator: EnsembleDiscriminator,
         mel_transform: nn.Module,
@@ -34,7 +43,10 @@ class VQGAN(L.LightningModule):
 
         # Generator and discriminators
         # Compile generator so that snake can save memory
-        self.encoder = encoder
+        self.semantic_encoder = semantic_encoder
+        self.posterior_encoder = posterior_encoder
+        self.speaker_encoder = speaker_encoder
+        # self.flow = flow
         self.generator = generator
         self.discriminator = discriminator
         self.mel_transform = mel_transform
@@ -50,7 +62,13 @@ class VQGAN(L.LightningModule):
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
         optimizer_generator = self.optimizer_builder(
-            itertools.chain(self.encoder.parameters(), self.generator.parameters())
+            itertools.chain(
+                self.semantic_encoder.parameters(),
+                self.posterior_encoder.parameters(),
+                self.speaker_encoder.parameters(),
+                self.generator.parameters(),
+                # self.flow.parameters(),
+            )
         )
         optimizer_discriminator = self.optimizer_builder(
             self.discriminator.parameters()
@@ -117,6 +135,31 @@ class VQGAN(L.LightningModule):
 
         return loss * 2
 
+    @staticmethod
+    def kl_loss(m_q, logs_q, m_p, logs_p, z_mask):
+        """
+        m_q, logs_q: [b, h, t_t]
+        m_p, logs_p: [b, h, t_t]
+        """
+        m_q = m_q.float()
+        logs_q = logs_q.float()
+        m_p = m_p.float()
+        logs_p = logs_p.float()
+        z_mask = z_mask.float()
+
+        kl = 0.5 * (
+            (m_q - m_p) ** 2 / torch.exp(logs_p)
+            + torch.exp(logs_q) / torch.exp(logs_p)
+            - 1
+            - logs_q
+            + logs_p
+        )
+
+        kl = torch.sum(kl * z_mask)
+        l = kl / torch.sum(z_mask)
+
+        return l
+
     def training_step(self, batch, batch_idx):
         optim_g, optim_d = self.optimizers()
 
@@ -127,7 +170,8 @@ class VQGAN(L.LightningModule):
         features = features.float()
 
         with torch.no_grad():
-            gt_mels = self.mel_transform(audios).transpose(1, 2)
+            gt_mels, gt_specs = self.mel_transform(audios, return_linear=True)
+            gt_mels = gt_mels.transpose(1, 2)
             key_padding_mask = sequence_mask(feature_lengths)
             mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
             audio_masks = sequence_mask(audio_lengths)[:, None]
@@ -135,6 +179,7 @@ class VQGAN(L.LightningModule):
             assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
             gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
             gt_mels = gt_mels[:, :gt_mel_length]
+            gt_specs = gt_specs[:, :, :gt_mel_length]
             mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
 
             assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
@@ -144,39 +189,19 @@ class VQGAN(L.LightningModule):
 
         audios = audios[:, None, :]
 
-        # # Get slice of audio
-        # if audios.shape[-1] > self.segment_size:
-        #     start = torch.randint(
-        #         0, audios.shape[-1] - self.segment_size, (1,), device=audios.device
-        #     ).item()
-        #     start = start // self.hop_length * self.hop_length
-
-        #     audios = audios[:, :, start : start + self.segment_size]
-        #     audio_masks = sequence_mask(audio_lengths)[
-        #         :, None, start : start + self.segment_size
-        #     ]
-
-        #     mel_start = start // self.hop_length
-        #     mel_size = self.segment_size // self.hop_length
-        #     gt_mels = gt_mels[:, mel_start : mel_start + mel_size]
-        #     mels_key_padding_mask = mels_key_padding_mask[
-        #         :, mel_start : mel_start + mel_size
-        #     ]
-
-        #     features = features[:, :, mel_start : mel_start + mel_size]
-
-        # Generator
-        encoded = self.encoder(
+        speaker = self.speaker_encoder(gt_mels, mels_key_padding_mask)[:, :, None]
+        prior = self.semantic_encoder(
             x=features,
-            mels=gt_mels,
             key_padding_mask=key_padding_mask,
-            mels_key_padding_mask=mels_key_padding_mask,
+            g=speaker,
         )
 
-        features = encoded.features
-        # features = self.naive_proj(features.transpose(1, 2))
-
-        fake_audios = self.generator(features)
+        posterior_key_padding_mask = (~mels_key_padding_mask).float()[:, None]
+        posterior = self.posterior_encoder(
+            gt_specs, posterior_key_padding_mask, g=speaker
+        )
+        # z_p = self.flow(posterior.mean, posterior_key_padding_mask, g=speaker)
+        fake_audios = self.generator(posterior.z, g=speaker)
 
         min_audio_length = min(audios.shape[-1], fake_audios.shape[-1])
         audios = audios[:, :, :min_audio_length]
@@ -227,8 +252,15 @@ class VQGAN(L.LightningModule):
             loss_mel = F.l1_loss(gt_mels, fake_mels)
             loss_adv, _ = self.generator_loss(y_d_hat_g)
             loss_fm = self.feature_loss(fmap_r, fmap_g)
+            loss_kl = self.kl_loss(
+                posterior.mean,
+                posterior.logs,
+                prior.mean,
+                prior.logs,
+                posterior_key_padding_mask,
+            )
 
-            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + encoded.loss
+            loss_gen_all = loss_mel * 45 + loss_fm + loss_adv + prior.loss + loss_kl
 
         self.log(
             "train/generator/loss",
@@ -266,9 +298,18 @@ class VQGAN(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+        self.log(
+            "train/generator/loss_kl",
+            loss_kl,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
         self.log(
             "train/generator/loss_vq",
-            encoded.loss,
+            prior.loss,
             on_step=True,
             on_epoch=False,
             prog_bar=False,
@@ -296,14 +337,15 @@ class VQGAN(L.LightningModule):
         features = features.float()
 
         with torch.no_grad():
-            gt_mels = self.mel_transform(audios).transpose(1, 2)
+            gt_mels, gt_specs = self.mel_transform(audios, return_linear=True)
+            gt_mels = gt_mels.transpose(1, 2)
             key_padding_mask = sequence_mask(feature_lengths)
             mels_key_padding_mask = sequence_mask(audio_lengths // self.hop_length)
-            audio_masks = sequence_mask(audio_lengths)
 
             assert abs(gt_mels.shape[1] - mels_key_padding_mask.shape[1]) <= 1
             gt_mel_length = min(gt_mels.shape[1], mels_key_padding_mask.shape[1])
             gt_mels = gt_mels[:, :gt_mel_length]
+            gt_specs = gt_specs[:, :, :gt_mel_length]
             mels_key_padding_mask = mels_key_padding_mask[:, :gt_mel_length]
 
             assert abs(features.shape[1] - key_padding_mask.shape[1]) <= 1
@@ -312,37 +354,41 @@ class VQGAN(L.LightningModule):
             key_padding_mask = key_padding_mask[:, :gt_feature_length]
 
         # Generator
-        encoded = self.encoder(
+        # speaker: (B, C, 1)
+        speaker = self.speaker_encoder(gt_mels, mels_key_padding_mask)[:, :, None]
+        posterior_key_padding_mask = (~mels_key_padding_mask).float()[:, None]
+
+        z_gen = self.semantic_encoder(
             x=features,
-            mels=gt_mels,
             key_padding_mask=key_padding_mask,
-            mels_key_padding_mask=mels_key_padding_mask,
-        )
+            g=speaker,
+        ).z
 
-        # features = self.naive_proj(features.transpose(1, 2))
+        # z_gen = self.flow(z_gen, posterior_key_padding_mask, g=speaker, reverse=True)
 
-        features = encoded.features
-        audios = audios[:, None, :]
+        z_posterior = self.posterior_encoder(
+            gt_specs, posterior_key_padding_mask, g=speaker
+        ).mean
 
-        fake_audios = self.generator(features)
-        min_audio_length = min(audios.shape[-1], fake_audios.shape[-1])
+        audios = audios[:, None, :]
+        fake_audios = self.generator(z_gen, g=speaker)
+        posterior_audios = self.generator(z_posterior)
+        min_audio_length = min(
+            audios.shape[-1], fake_audios.shape[-1], posterior_audios.shape[-1]
+        )
 
         audios = audios[:, :, :min_audio_length]
         fake_audios = fake_audios[:, :, :min_audio_length]
-        audio_masks = audio_masks[:, None, :min_audio_length]
-
-        audio = torch.masked_fill(audios, audio_masks, 0.0)
-        fake_audios = torch.masked_fill(fake_audios, audio_masks, 0.0)
-        assert fake_audios.shape == audio.shape
+        posterior_audios = posterior_audios[:, :, :min_audio_length]
+        assert fake_audios.shape == audios.shape == posterior_audios.shape
 
         fake_mels = self.mel_transform(fake_audios.squeeze(1)).transpose(1, 2)
+        posterior_mels = self.mel_transform(posterior_audios.squeeze(1)).transpose(1, 2)
+
         min_mel_length = min(gt_mels.shape[1], fake_mels.shape[1])
         gt_mels = gt_mels[:, :min_mel_length]
         fake_mels = fake_mels[:, :min_mel_length]
-        mels_key_padding_mask = mels_key_padding_mask[:, :min_mel_length]
-
-        gt_mels = torch.masked_fill(gt_mels, mels_key_padding_mask[:, :, None], 0.0)
-        fake_mels = torch.masked_fill(fake_mels, mels_key_padding_mask[:, :, None], 0.0)
+        posterior_mels = posterior_mels[:, :min_mel_length]
 
         mel_loss = F.l1_loss(gt_mels, fake_mels)
         self.log(
@@ -355,12 +401,22 @@ class VQGAN(L.LightningModule):
             sync_dist=True,
         )
 
-        for idx, (mel, gen_mel, audio, gen_audio, audio_len) in enumerate(
+        for idx, (
+            mel,
+            gen_mel,
+            post_mel,
+            audio,
+            gen_audio,
+            post_audio,
+            audio_len,
+        ) in enumerate(
             zip(
                 gt_mels.transpose(1, 2),
                 fake_mels.transpose(1, 2),
+                posterior_mels.transpose(1, 2),
                 audios,
                 fake_audios,
+                posterior_audios,
                 audio_lengths,
             )
         ):
@@ -369,9 +425,14 @@ class VQGAN(L.LightningModule):
             image_mels = plot_mel(
                 [
                     gen_mel[:, :mel_len],
+                    post_mel[:, :mel_len],
                     mel[:, :mel_len],
                 ],
-                ["Sampled Spectrogram", "Ground-Truth Spectrogram"],
+                [
+                    "Generated Spectrogram",
+                    "Posterior Spectrogram",
+                    "Ground-Truth Spectrogram",
+                ],
             )
 
             if isinstance(self.logger, WandbLogger):
@@ -389,6 +450,11 @@ class VQGAN(L.LightningModule):
                                 sample_rate=self.sampling_rate,
                                 caption="prediction",
                             ),
+                            wandb.Audio(
+                                post_audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="posterior",
+                            ),
                         ],
                     },
                 )
@@ -411,5 +477,11 @@ class VQGAN(L.LightningModule):
                     self.global_step,
                     sample_rate=self.sampling_rate,
                 )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/posterior",
+                    post_audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
 
             plt.close(image_mels)

+ 161 - 36
fish_speech/models/vqgan/modules.py

@@ -6,7 +6,8 @@ from encodec.quantization.core_vq import VectorQuantization
 from torch import nn
 from torch.nn import Conv1d, Conv2d, ConvTranspose1d
 from torch.nn import functional as F
-from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
+from torch.nn.utils.parametrizations import spectral_norm, weight_norm
+from torch.nn.utils.parametrize import remove_parametrizations
 
 from fish_speech.models.vqgan.utils import (
     convert_pad_shape,
@@ -18,24 +19,138 @@ from fish_speech.models.vqgan.utils import (
 LRELU_SLOPE = 0.1
 
 
+class ResidualCouplingBlock(nn.Module):
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size: int = 5,
+        dilation_rate: int = 1,
+        n_layers: int = 4,
+        n_flows: int = 4,
+        gin_channels: int = 512,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.n_flows = n_flows
+        self.gin_channels = gin_channels
+
+        self.flows = nn.ModuleList()
+        for i in range(n_flows):
+            self.flows.append(
+                ResidualCouplingLayer(
+                    channels,
+                    hidden_channels,
+                    kernel_size,
+                    dilation_rate,
+                    n_layers,
+                    gin_channels=gin_channels,
+                    mean_only=True,
+                )
+            )
+            self.flows.append(Flip())
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        if not reverse:
+            for flow in self.flows:
+                x, _ = flow(x, x_mask, g=g, reverse=reverse)
+        else:
+            for flow in reversed(self.flows):
+                x = flow(x, x_mask, g=g, reverse=reverse)
+        return x
+
+
+class Flip(nn.Module):
+    def forward(self, x, *args, reverse=False, **kwargs):
+        x = torch.flip(x, [1])
+        if not reverse:
+            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+            return x, logdet
+        else:
+            return x
+
+
+class ResidualCouplingLayer(nn.Module):
+    def __init__(
+        self,
+        channels,
+        hidden_channels,
+        kernel_size,
+        dilation_rate,
+        n_layers,
+        p_dropout=0,
+        gin_channels=0,
+        mean_only=False,
+    ):
+        assert channels % 2 == 0, "channels should be divisible by 2"
+        super().__init__()
+        self.channels = channels
+        self.hidden_channels = hidden_channels
+        self.kernel_size = kernel_size
+        self.dilation_rate = dilation_rate
+        self.n_layers = n_layers
+        self.half_channels = channels // 2
+        self.mean_only = mean_only
+
+        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+        self.enc = WaveNet(
+            hidden_channels,
+            kernel_size,
+            dilation_rate,
+            n_layers,
+            p_dropout=p_dropout,
+            gin_channels=gin_channels,
+        )
+        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+        self.post.weight.data.zero_()
+        self.post.bias.data.zero_()
+
+    def forward(self, x, x_mask, g=None, reverse=False):
+        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+        h = self.pre(x0) * x_mask
+        h = self.enc(h, x_mask, g=g)
+        stats = self.post(h) * x_mask
+        if not self.mean_only:
+            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+        else:
+            m = stats
+            logs = torch.zeros_like(m)
+
+        if not reverse:
+            x1 = m + x1 * torch.exp(logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            logdet = torch.sum(logs, [1, 2])
+            return x, logdet
+        else:
+            x1 = (x1 - m) * torch.exp(-logs) * x_mask
+            x = torch.cat([x0, x1], 1)
+            return x
+
+
 @dataclass
-class VQEncoderOutput:
+class SemanticEncoderOutput:
     loss: torch.Tensor
     mean: torch.Tensor
     logs: torch.Tensor
+    z: torch.Tensor
 
 
-class VQEncoder(nn.Module):
+class SemanticEncoder(nn.Module):
     def __init__(
         self,
         in_channels: int = 1024,
-        channels: int = 384,
+        hidden_channels: int = 384,
         out_channels: int = 192,
         num_heads: int = 2,
         num_layers: int = 8,
         input_downsample: bool = True,
         code_book_size: int = 2048,
         freeze_vq: bool = False,
+        gin_channels: int = 512,
     ):
         super().__init__()
 
@@ -54,16 +169,17 @@ class VQEncoder(nn.Module):
         )
 
         # Init weights of in_proj to mimic the effect of avg pooling
-        torch.nn.init.normal_(
+        nn.init.normal_(
             self.in_proj.weight, mean=1 / (down_sample * in_channels), std=0.01
         )
         self.in_proj.bias.data.zero_()
 
-        self.feature_in = nn.Linear(in_channels, channels)
+        self.feature_in = nn.Linear(in_channels, hidden_channels)
+        self.g_in = nn.Conv1d(gin_channels, hidden_channels, 1)
         self.blocks = nn.ModuleList(
             [
                 TransformerBlock(
-                    channels,
+                    hidden_channels,
                     num_heads,
                     window_size=4,
                     window_heads_share=True,
@@ -75,7 +191,7 @@ class VQEncoder(nn.Module):
             ]
         )
 
-        self.out_proj = nn.Linear(channels, out_channels * 2)
+        self.out_proj = nn.Linear(hidden_channels, out_channels * 2)
 
         self.input_downsample = input_downsample
 
@@ -86,7 +202,7 @@ class VQEncoder(nn.Module):
             for p in self.vq_in.parameters():
                 p.requires_grad = False
 
-    def forward(self, x, key_padding_mask=None) -> VQEncoderOutput:
+    def forward(self, x, key_padding_mask=None, g=None) -> SemanticEncoderOutput:
         # x: (batch, seq_len, channels)
 
         assert key_padding_mask.size(1) == x.size(
@@ -100,7 +216,7 @@ class VQEncoder(nn.Module):
 
         if self.input_downsample:
             features = F.interpolate(
-                features.transpose(1, 2), scale_factor=2
+                features.transpose(1, 2), scale_factor=2, mode="nearest"
             ).transpose(1, 2)
 
         # Shape may change due to downsampling, let's cut it to the same size
@@ -111,6 +227,9 @@ class VQEncoder(nn.Module):
             key_padding_mask = key_padding_mask[:, :min_len]
 
         features = self.feature_in(features)
+        g = self.g_in(g).transpose(1, 2)
+        features = features + g
+
         for block in self.blocks:
             features = block(features, key_padding_mask=key_padding_mask)
 
@@ -118,10 +237,11 @@ class VQEncoder(nn.Module):
         stats = torch.masked_fill(stats, key_padding_mask.unsqueeze(1), 0)
         mean, logs = torch.chunk(stats, 2, dim=1)
 
-        return VQEncoderOutput(
+        return SemanticEncoderOutput(
             loss=loss,
             mean=mean,
             logs=logs,
+            z=mean + torch.randn_like(mean) * torch.exp(logs) * 0.5,
         )
 
 
@@ -173,8 +293,9 @@ class WaveNet(nn.Module):
             else:
                 res_skip_channels = hidden_channels
 
-            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
+            res_skip_layer = weight_norm(
+                nn.Conv1d(hidden_channels, res_skip_channels, 1), name="weight"
+            )
             self.res_skip_layers.append(res_skip_layer)
 
     def forward(self, x, x_mask, g=None):
@@ -205,13 +326,13 @@ class WaveNet(nn.Module):
 
         return output * x_mask
 
-    def remove_weight_norm(self):
+    def remove_parametrizations(self):
         if self.gin_channels != 0:
-            torch.nn.utils.remove_weight_norm(self.cond_layer)
+            nn.utils.remove_parametrizations(self.cond_layer)
         for l in self.in_layers:
-            torch.nn.utils.remove_weight_norm(l)
+            nn.utils.remove_parametrizations(l)
         for l in self.res_skip_layers:
-            torch.nn.utils.remove_weight_norm(l)
+            nn.utils.remove_parametrizations(l)
 
 
 @dataclass
@@ -270,26 +391,26 @@ class SpeakerEncoder(nn.Module):
     def __init__(
         self,
         in_channels: int = 128,
-        channels: int = 192,
+        hidden_channels: int = 192,
         out_channels: int = 512,
         num_heads: int = 2,
         num_layers: int = 4,
     ) -> None:
         super().__init__()
 
-        self.query = nn.Parameter(torch.randn(1, 1, channels))
-        self.in_proj = nn.Linear(in_channels, channels)
+        self.query = nn.Parameter(torch.randn(1, 1, hidden_channels))
+        self.in_proj = nn.Linear(in_channels, hidden_channels)
         self.blocks = nn.ModuleList(
             [
                 TransformerBlock(
-                    channels,
+                    hidden_channels,
                     num_heads,
                     use_relative_attn=False,
                 )
                 for _ in range(num_layers)
             ]
         )
-        self.out_proj = nn.Linear(channels, out_channels)
+        self.out_proj = nn.Linear(hidden_channels, out_channels)
 
     def forward(self, mels, mels_key_padding_mask=None):
         x = self.in_proj(mels)
@@ -305,10 +426,9 @@ class SpeakerEncoder(nn.Module):
         for block in self.blocks:
             x = block(x, key_padding_mask=mels_key_padding_mask)
 
-        x = x[:, :1]
-        x = self.out_proj(x)
+        x = self.out_proj(x[:, 0])
 
-        return x.transpose(1, 2)
+        return x
 
 
 class TransformerBlock(nn.Module):
@@ -616,7 +736,7 @@ class RelativeAttention(nn.Module):
         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
 
 
-class ResBlock1(torch.nn.Module):
+class ResBlock1(nn.Module):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
         super(ResBlock1, self).__init__()
         self.convs1 = nn.ModuleList(
@@ -706,14 +826,14 @@ class ResBlock1(torch.nn.Module):
             x = x * x_mask
         return x
 
-    def remove_weight_norm(self):
+    def remove_parametrizations(self):
         for l in self.convs1:
-            remove_weight_norm(l)
+            remove_parametrizations(l)
         for l in self.convs2:
-            remove_weight_norm(l)
+            remove_parametrizations(l)
 
 
-class ResBlock2(torch.nn.Module):
+class ResBlock2(nn.Module):
     def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
         super(ResBlock2, self).__init__()
         self.convs = nn.ModuleList(
@@ -753,9 +873,9 @@ class ResBlock2(torch.nn.Module):
             x = x * x_mask
         return x
 
-    def remove_weight_norm(self):
+    def remove_parametrizations(self):
         for l in self.convs:
-            remove_weight_norm(l)
+            remove_parametrizations(l)
 
 
 class Generator(nn.Module):
@@ -768,6 +888,7 @@ class Generator(nn.Module):
         upsample_rates,
         upsample_initial_channel,
         upsample_kernel_sizes,
+        gin_channels,
     ):
         super(Generator, self).__init__()
         self.num_kernels = len(resblock_kernel_sizes)
@@ -802,8 +923,12 @@ class Generator(nn.Module):
         self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
         self.ups.apply(init_weights)
 
-    def forward(self, x):
+        self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
+
+    def forward(self, x, g=None):
         x = self.conv_pre(x)
+        if g is not None:
+            g = self.cond(g)
 
         for i in range(self.num_upsamples):
             x = F.leaky_relu(x, LRELU_SLOPE)
@@ -821,12 +946,12 @@ class Generator(nn.Module):
 
         return x
 
-    def remove_weight_norm(self):
+    def remove_parametrizations(self):
         print("Removing weight norm...")
         for l in self.ups:
-            remove_weight_norm(l)
+            remove_parametrizations(l)
         for l in self.resblocks:
-            l.remove_weight_norm()
+            l.remove_parametrizations()
 
 
 class DiscriminatorP(nn.Module):

+ 23 - 10
tools/vqgan/migrate_from_vits.py

@@ -25,17 +25,30 @@ def main(cfg: DictConfig):
     logger.info(f"Model loaded, restoring from {generator_ckpt}")
     generator_weights = torch.load(generator_ckpt, map_location="cpu")["model"]
 
-    # HiFiGAN
+    # Decoder
     generator_state = {
-        k[4:]: v
-        for k, v in generator_weights.items()
-        if k.startswith("dec.") and not k.startswith("dec.cond.")
+        k[4:]: v for k, v in generator_weights.items() if k.startswith("dec.")
     }
-
     logger.info(f"Found {len(generator_state)} HiFiGAN weights, restoring...")
     model.generator.load_state_dict(generator_state, strict=True)
     logger.info("Generator weights restored.")
 
+    # Posterior Encoder
+    encoder_state = {
+        k[6:]: v for k, v in generator_weights.items() if k.startswith("enc_q.")
+    }
+    logger.info(f"Found {len(encoder_state)} posterior encoder weights, restoring...")
+    model.posterior_encoder.load_state_dict(encoder_state, strict=True)
+    logger.info("Posterior encoder weights restored.")
+
+    # Flow
+    # flow_state = {
+    #     k[5:]: v for k, v in generator_weights.items() if k.startswith("flow.")
+    # }
+    # logger.info(f"Found {len(flow_state)} flow weights, restoring...")
+    # model.flow.load_state_dict(flow_state, strict=True)
+    # logger.info("Flow weights restored.")
+
     # Discriminator
     logger.info(f"Model loaded, restoring from {discriminator_ckpt}")
     discriminator_weights = torch.load(discriminator_ckpt, map_location="cpu")["model"]
@@ -48,15 +61,15 @@ def main(cfg: DictConfig):
     # Restore kmeans
     logger.info("Reset vq projection layer to mimic avg pooling")
     torch.nn.init.normal_(
-        model.encoder.in_proj.weight,
+        model.semantic_encoder.in_proj.weight,
         mean=1
         / (
-            model.encoder.in_proj.weight.shape[0]
-            * model.encoder.in_proj.weight.shape[-1]
+            model.semantic_encoder.in_proj.weight.shape[0]
+            * model.semantic_encoder.in_proj.weight.shape[-1]
         ),
         std=1e-2,
     )
-    model.encoder.in_proj.bias.data.zero_()
+    model.semantic_encoder.in_proj.bias.data.zero_()
 
     kmeans_ckpt = "results/hubert-vq-pretrain/kmeans.pt"
     kmeans_ckpt = torch.load(kmeans_ckpt, map_location="cpu")
@@ -74,7 +87,7 @@ def main(cfg: DictConfig):
         "_codebook.embed_avg": centroids.clone(),
     }
 
-    model.encoder.vq.load_state_dict(state_dict, strict=True)
+    model.semantic_encoder.vq.load_state_dict(state_dict, strict=True)
 
     torch.save(model.state_dict(), cfg.ckpt_path)
     logger.info("Done")