Lengyue 1 year ago
parent
commit
fdc5712f9c

+ 119 - 0
fish_speech/configs/vits_decoder.yaml

@@ -0,0 +1,119 @@
+defaults:
+  - base
+  - _self_
+
+project: vits_decoder
+ckpt_path: checkpoints/Bert-VITS2/ensemble.pth
+resume_weights_only: true
+
+# Lightning Trainer
+trainer:
+  accelerator: gpu
+  devices: auto
+  strategy: ddp_find_unused_parameters_true
+  precision: 32
+  max_steps: 1_000_000
+  val_check_interval: 2000
+
+sample_rate: 44100
+hop_length: 512
+num_mels: 128
+n_fft: 2048
+win_length: 2048
+
+# Tokenizer Configuration
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: fishaudio/fish-speech-1
+
+# Dataset Configuration
+train_dataset:
+  _target_: fish_speech.datasets.vits.VITSDataset
+  filelist: data/source/Genshin/filelist.train.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  suffix: ".lab"
+  tokenizer: ${tokenizer}
+
+val_dataset:
+  _target_: fish_speech.datasets.vits.VITSDataset
+  filelist: data/source/Genshin/filelist.test.txt
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  suffix: ".lab"
+  tokenizer: ${tokenizer}
+
+data:
+  _target_: fish_speech.datasets.vits.VITSDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 8
+  val_batch_size: 4
+  tokenizer: ${tokenizer}
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vits_decoder.VITSDecoder
+  sample_rate: ${sample_rate}
+  hop_length: ${hop_length}
+  freeze_discriminator: false
+
+  weight_mel: 45.0
+  weight_kl: 1.0
+
+  generator:
+    _target_: fish_speech.models.vits_decoder.modules.models.SynthesizerTrn
+    spec_channels: 1025
+    segment_size: 32
+    inter_channels: 192
+    hidden_channels: 192
+    filter_channels: 768
+    n_heads: 2
+    n_layers: 6
+    kernel_size: 3
+    p_dropout: 0.1
+    resblock: "1"
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+    upsample_rates: [8, 8, 2, 2, 2]
+    upsample_initial_channel: 512
+    upsample_kernel_sizes: [16, 16, 8, 2, 2]
+    gin_channels: 512
+
+  discriminator:
+    _target_: fish_speech.models.vits_decoder.modules.models.EnsembledDiscriminator
+    periods: [2, 3, 5, 7, 11]
+
+  mel_transform:
+    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
+    sample_rate: ${sample_rate}
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    n_mels: ${num_mels}
+
+  spec_transform:
+    _target_: fish_speech.utils.spectrogram.LinearSpectrogram
+    n_fft: ${n_fft}
+    hop_length: ${hop_length}
+    win_length: ${win_length}
+    mode: pow2_sqrt
+  
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    betas: [0.8, 0.99]
+    eps: 1e-6
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.ExponentialLR
+    _partial_: true
+    gamma: 0.999875
+
+callbacks:
+  grad_norm_monitor:
+    sub_module: 
+      - generator
+      - discriminator
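Not part of the commit: a minimal sketch of how this new config could be composed and instantiated, assuming the Hydra layout in fish_speech/configs and a plain Lightning fit loop; the project's actual train entrypoint may wire things differently (ckpt_path and resume_weights_only are consumed there, not below).

# Sketch only; assumes it is run from the repository root.
from hydra import compose, initialize
from hydra.utils import instantiate
import lightning as L

with initialize(version_base=None, config_path="fish_speech/configs"):
    cfg = compose(config_name="vits_decoder")

model = instantiate(cfg.model)  # VITSDecoder with generator, discriminator, transforms
data = instantiate(cfg.data)    # VITSDataModule wrapping the train/val VITSDataset

trainer = L.Trainer(
    accelerator="gpu",
    devices="auto",
    precision=32,
    max_steps=1_000_000,
    val_check_interval=2000,
)
trainer.fit(model, datamodule=data)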

+ 2 - 2
fish_speech/configs/vqgan_finetune.yaml

@@ -86,7 +86,7 @@ model:
     ckpt_path: null # You may download the pretrained vocoder and set the path here
 
   encode_mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
     sample_rate: ${sample_rate}
     n_fft: ${n_fft}
     hop_length: ${hop_length}
@@ -96,7 +96,7 @@ model:
     f_max: 8000.0
 
   gt_mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
     sample_rate: ${sample_rate}
     n_fft: ${n_fft}
     hop_length: ${hop_length}

+ 2 - 2
fish_speech/configs/vqgan_pretrain.yaml

@@ -89,7 +89,7 @@ model:
     ckpt_path: null # You may download the pretrained vocoder and set the path here
 
   encode_mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
     sample_rate: ${sample_rate}
     n_fft: ${n_fft}
     hop_length: ${hop_length}
@@ -99,7 +99,7 @@ model:
     f_max: 8000.0
 
   gt_mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
+    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
     sample_rate: ${sample_rate}
     n_fft: ${n_fft}
     hop_length: ${hop_length}

+ 180 - 0
fish_speech/datasets/vits.py

@@ -0,0 +1,180 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import librosa
+import numpy as np
+import torch
+import torch.distributed as dist
+from lightning import LightningDataModule
+from torch.utils.data import DataLoader, Dataset
+from torch.utils.data.distributed import DistributedSampler
+from transformers import AutoTokenizer
+
+from fish_speech.utils import RankedLogger
+
+logger = RankedLogger(__name__, rank_zero_only=False)
+
+
+class VITSDataset(Dataset):
+    def __init__(
+        self,
+        filelist: str,
+        tokenizer: AutoTokenizer,
+        sample_rate: int = 44100,
+        hop_length: int = 512,
+        min_duration: float = 1.5,
+        max_duration: float = 30.0,
+        suffix: str = ".lab",
+    ):
+        super().__init__()
+
+        filelist = Path(filelist)
+        root = filelist.parent
+
+        self.files = []
+        for line in filelist.read_text(encoding="utf-8").splitlines():
+            path = root / line
+            self.files.append(path)
+
+        self.sample_rate = sample_rate
+        self.hop_length = hop_length
+        self.min_duration = min_duration
+        self.max_duration = max_duration
+        self.tokenizer = tokenizer
+        self.suffix = suffix
+
+    def __len__(self):
+        return len(self.files)
+
+    def get_item(self, idx):
+        audio_file = self.files[idx]
+        text_file = audio_file.with_suffix(self.suffix)
+
+        if not text_file.exists() or not audio_file.exists():
+            return None
+
+        audio, _ = librosa.load(audio_file, sr=self.sample_rate, mono=True)
+        duration = len(audio) / self.sample_rate
+
+        if (
+            len(audio) == 0
+            or duration < self.min_duration
+            or duration > self.max_duration
+        ):
+            return None
+
+        max_value = np.abs(audio).max()
+        if max_value > 1.0:
+            audio = audio / max_value
+
+        text = text_file.read_text(encoding="utf-8")
+        input_ids = self.tokenizer(text, return_tensors="pt").input_ids.squeeze(0)
+
+        return {
+            "audio": torch.from_numpy(audio),
+            "text": input_ids,
+        }
+
+    def __getitem__(self, idx):
+        try:
+            return self.get_item(idx)
+        except Exception as e:
+            import traceback
+
+            traceback.print_exc()
+            logger.error(f"Error loading {self.files[idx]}: {e}")
+            return None
+
+
+@dataclass
+class VITSCollator:
+    tokenizer: AutoTokenizer
+
+    def __call__(self, batch):
+        batch = [x for x in batch if x is not None]
+
+        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
+        audio_maxlen = audio_lengths.max()
+
+        text_lengths = torch.tensor([len(x["text"]) for x in batch])
+        text_maxlen = text_lengths.max()
+
+        # Pad audios and texts to the longest sample in the batch
+        audios = []
+        texts = []
+        for x in batch:
+            audios.append(
+                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
+            )
+
+            texts.append(
+                torch.nn.functional.pad(
+                    x["text"],
+                    (0, text_maxlen - len(x["text"])),
+                    value=self.tokenizer.eos_token_id,
+                )
+            )
+
+        return {
+            "audios": torch.stack(audios),
+            "audio_lengths": audio_lengths,
+            "texts": torch.stack(texts),
+            "text_lengths": text_lengths,
+        }
+
+
+class VITSDataModule(LightningDataModule):
+    def __init__(
+        self,
+        train_dataset: VITSDataset,
+        val_dataset: VITSDataset,
+        tokenizer: AutoTokenizer,
+        batch_size: int = 32,
+        num_workers: int = 4,
+        val_batch_size: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.batch_size = batch_size
+        self.val_batch_size = val_batch_size or batch_size
+        self.num_workers = num_workers
+        self.tokenizer = tokenizer
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            collate_fn=VITSCollator(self.tokenizer),
+            num_workers=self.num_workers,
+            shuffle=False,
+            persistent_workers=True,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.val_batch_size,
+            collate_fn=VITSCollator(self.tokenizer),
+            num_workers=self.num_workers,
+            persistent_workers=True,
+        )
+
+
+if __name__ == "__main__":
+    tokenizer = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
+    dataset = VITSDataset(
+        "data/source/Genshin/filelist.train.txt", tokenizer=tokenizer, suffix=".lab"
+    )
+    dataloader = DataLoader(
+        dataset, batch_size=4, shuffle=False, collate_fn=VITSCollator(tokenizer)
+    )
+
+    for batch in dataloader:
+        print(batch["audios"].shape)
+        print(batch["audio_lengths"])
+        print(batch["texts"].shape)
+        print(batch["text_lengths"])
+        break
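For context (not part of the commit): VITSDataset above reads one audio path per filelist line, resolved against the filelist's own directory, and expects a sibling transcript with the configured suffix. A hypothetical layout, with made-up file names:

# Illustrative only; file and directory names are invented.
from pathlib import Path

root = Path("data/source/Genshin")
(root / "speaker_a").mkdir(parents=True, exist_ok=True)

# One relative audio path per line, resolved against the filelist's parent directory.
(root / "filelist.train.txt").write_text("speaker_a/clip_0001.wav\n", encoding="utf-8")

# The transcript shares the audio stem and uses the `.lab` suffix (see `suffix` above).
(root / "speaker_a" / "clip_0001.lab").write_text("Example transcript.", encoding="utf-8")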

+ 3 - 0
fish_speech/models/vits_decoder/__init__.py

@@ -0,0 +1,3 @@
+from .lit_module import VITSDecoder
+
+__all__ = ["VITSDecoder"]

+ 396 - 0
fish_speech/models/vits_decoder/lit_module.py

@@ -0,0 +1,396 @@
+import itertools
+from dataclasses import dataclass
+from typing import Any, Callable, Literal, Optional
+
+import lightning as L
+import torch
+import torch.nn.functional as F
+import wandb
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
+from matplotlib import pyplot as plt
+from torch import nn
+
+from fish_speech.models.vits_decoder.losses import (
+    discriminator_loss,
+    feature_loss,
+    generator_loss,
+    kl_loss,
+)
+from fish_speech.models.vqgan.utils import (
+    avg_with_mask,
+    plot_mel,
+    sequence_mask,
+    slice_segments,
+)
+
+
+class VITSDecoder(L.LightningModule):
+    def __init__(
+        self,
+        optimizer: Callable,
+        lr_scheduler: Callable,
+        generator: nn.Module,
+        discriminator: nn.Module,
+        mel_transform: nn.Module,
+        spec_transform: nn.Module,
+        hop_length: int = 512,
+        sample_rate: int = 44100,
+        freeze_discriminator: bool = False,
+        weight_mel: float = 45,
+        weight_kl: float = 0.1,
+    ):
+        super().__init__()
+
+        # Model parameters
+        self.optimizer_builder = optimizer
+        self.lr_scheduler_builder = lr_scheduler
+
+        # Generator and discriminator
+        self.generator = generator
+        self.discriminator = discriminator
+        self.mel_transform = mel_transform
+        self.spec_transform = spec_transform
+        self.freeze_discriminator = freeze_discriminator
+
+        # Loss weights
+        self.weight_mel = weight_mel
+        self.weight_kl = weight_kl
+
+        # Other parameters
+        self.hop_length = hop_length
+        self.sampling_rate = sample_rate
+
+        # Disable automatic optimization
+        self.automatic_optimization = False
+
+        if self.freeze_discriminator:
+            for p in self.discriminator.parameters():
+                p.requires_grad = False
+
+    def configure_optimizers(self):
+        # Need two optimizers and two schedulers
+        optimizer_generator = self.optimizer_builder(self.generator.parameters())
+        optimizer_discriminator = self.optimizer_builder(
+            self.discriminator.parameters()
+        )
+
+        lr_scheduler_generator = self.lr_scheduler_builder(optimizer_generator)
+        lr_scheduler_discriminator = self.lr_scheduler_builder(optimizer_discriminator)
+
+        return (
+            {
+                "optimizer": optimizer_generator,
+                "lr_scheduler": {
+                    "scheduler": lr_scheduler_generator,
+                    "interval": "step",
+                    "name": "optimizer/generator",
+                },
+            },
+            {
+                "optimizer": optimizer_discriminator,
+                "lr_scheduler": {
+                    "scheduler": lr_scheduler_discriminator,
+                    "interval": "step",
+                    "name": "optimizer/discriminator",
+                },
+            },
+        )
+
+    def training_step(self, batch, batch_idx):
+        optim_g, optim_d = self.optimizers()
+
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+        texts, text_lengths = batch["texts"], batch["text_lengths"]
+
+        audios = audios.float()
+        audios = audios[:, None, :]
+
+        with torch.no_grad():
+            gt_mels = self.mel_transform(audios)
+            gt_specs = self.spec_transform(audios)
+
+        spec_lengths = audio_lengths // self.hop_length
+        spec_masks = torch.unsqueeze(
+            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
+        ).to(gt_mels.dtype)
+
+        (
+            fake_audios,
+            ids_slice,
+            y_mask,
+            (z, z_p, m_p, logs_p, m_q, logs_q),
+        ) = self.generator(
+            audios,
+            audio_lengths,
+            gt_specs,
+            spec_lengths,
+            texts,
+            text_lengths,
+        )
+
+        gt_mels = slice_segments(gt_mels, ids_slice, self.generator.segment_size)
+        spec_masks = slice_segments(spec_masks, ids_slice, self.generator.segment_size)
+        audios = slice_segments(
+            audios,
+            ids_slice * self.hop_length,
+            self.generator.segment_size * self.hop_length,
+        )
+        fake_mels = self.mel_transform(fake_audios.squeeze(1))
+
+        assert (
+            audios.shape == fake_audios.shape
+        ), f"{audios.shape} != {fake_audios.shape}"
+
+        # Discriminator
+        if not self.freeze_discriminator:
+            y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(
+                audios, fake_audios.detach()
+            )
+
+            with torch.autocast(device_type=audios.device.type, enabled=False):
+                loss_disc, _, _ = discriminator_loss(y_d_hat_r, y_d_hat_g)
+
+            self.log(
+                f"train/discriminator/loss",
+                loss_disc,
+                on_step=True,
+                on_epoch=False,
+                prog_bar=False,
+                logger=True,
+                sync_dist=True,
+            )
+
+            optim_d.zero_grad()
+            self.manual_backward(loss_disc)
+            self.clip_gradients(
+                optim_d, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+            )
+            optim_d.step()
+
+        # Adv Loss
+        y_d_hat_r, y_d_hat_g, _, _ = self.discriminator(audios, fake_audios)
+
+        # Adversarial Loss
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_adv, _ = generator_loss(y_d_hat_g)
+
+        self.log(
+            f"train/generator/adv",
+            loss_adv,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_fm = feature_loss(y_d_hat_r, y_d_hat_g)
+
+        self.log(
+            f"train/generator/adv_fm",
+            loss_fm,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+
+        with torch.autocast(device_type=audios.device.type, enabled=False):
+            loss_mel = avg_with_mask(
+                F.l1_loss(gt_mels, fake_mels, reduction="none"), spec_masks
+            )
+
+        self.log(
+            "train/generator/loss_mel",
+            loss_mel,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+
+        loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, y_mask)
+
+        self.log(
+            "train/generator/loss_kl",
+            loss_kl,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+
+        loss = (
+            loss_mel * self.weight_mel + loss_kl * self.weight_kl + loss_adv + loss_fm
+        )
+        self.log(
+            "train/generator/loss",
+            loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+            sync_dist=True,
+        )
+
+        # Backward
+        optim_g.zero_grad()
+
+        self.manual_backward(loss)
+        self.clip_gradients(
+            optim_g, gradient_clip_val=1000.0, gradient_clip_algorithm="norm"
+        )
+        optim_g.step()
+
+        # Manual LR Scheduler
+        scheduler_g, scheduler_d = self.lr_schedulers()
+        scheduler_g.step()
+        scheduler_d.step()
+
+    def validation_step(self, batch: Any, batch_idx: int):
+        audios, audio_lengths = batch["audios"], batch["audio_lengths"]
+        texts, text_lengths = batch["texts"], batch["text_lengths"]
+
+        audios = audios.float()
+        audios = audios[:, None, :]
+
+        gt_mels = self.mel_transform(audios)
+        gt_specs = self.spec_transform(audios)
+        spec_lengths = audio_lengths // self.hop_length
+        spec_masks = torch.unsqueeze(
+            sequence_mask(spec_lengths, gt_mels.shape[2]), 1
+        ).to(gt_mels.dtype)
+
+        prior_audios = self.generator.infer(
+            audios, audio_lengths, gt_specs, spec_lengths, texts, text_lengths
+        )
+        posterior_audios = self.generator.infer_posterior(gt_specs, spec_lengths)
+        prior_mels = self.mel_transform(prior_audios.squeeze(1))
+        posterior_mels = self.mel_transform(posterior_audios.squeeze(1))
+
+        min_mel_length = min(
+            gt_mels.shape[-1], prior_mels.shape[-1], posterior_mels.shape[-1]
+        )
+        gt_mels = gt_mels[:, :, :min_mel_length]
+        prior_mels = prior_mels[:, :, :min_mel_length]
+        posterior_mels = posterior_mels[:, :, :min_mel_length]
+
+        prior_mel_loss = avg_with_mask(
+            F.l1_loss(gt_mels, prior_mels, reduction="none"), spec_masks
+        )
+        posterior_mel_loss = avg_with_mask(
+            F.l1_loss(gt_mels, posterior_mels, reduction="none"), spec_masks
+        )
+
+        self.log(
+            "val/prior_mel_loss",
+            prior_mel_loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+
+        self.log(
+            "val/posterior_mel_loss",
+            posterior_mel_loss,
+            on_step=False,
+            on_epoch=True,
+            prog_bar=False,
+            logger=True,
+            sync_dist=True,
+        )
+
+        # only log the first batch
+        if batch_idx != 0:
+            return
+
+        for idx, (
+            mel,
+            prior_mel,
+            posterior_mel,
+            audio,
+            prior_audio,
+            posterior_audio,
+            audio_len,
+        ) in enumerate(
+            zip(
+                gt_mels,
+                prior_mels,
+                posterior_mels,
+                audios.detach().float(),
+                prior_audios.detach().float(),
+                posterior_audios.detach().float(),
+                audio_lengths,
+            )
+        ):
+            mel_len = audio_len // self.hop_length
+
+            image_mels = plot_mel(
+                [
+                    prior_mel[:, :mel_len],
+                    posterior_mel[:, :mel_len],
+                    mel[:, :mel_len],
+                ],
+                [
+                    "Prior (VQ)",
+                    "Posterior (Reconstruction)",
+                    "Ground-Truth",
+                ],
+            )
+
+            if isinstance(self.logger, WandbLogger):
+                self.logger.experiment.log(
+                    {
+                        "reconstruction_mel": wandb.Image(image_mels, caption="mels"),
+                        "wavs": [
+                            wandb.Audio(
+                                audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="gt",
+                            ),
+                            wandb.Audio(
+                                prior_audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="prior",
+                            ),
+                            wandb.Audio(
+                                posterior_audio[0, :audio_len],
+                                sample_rate=self.sampling_rate,
+                                caption="posterior",
+                            ),
+                        ],
+                    },
+                )
+
+            if isinstance(self.logger, TensorBoardLogger):
+                self.logger.experiment.add_figure(
+                    f"sample-{idx}/mels",
+                    image_mels,
+                    global_step=self.global_step,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/gt",
+                    audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/prior",
+                    prior_audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+                self.logger.experiment.add_audio(
+                    f"sample-{idx}/wavs/posterior",
+                    posterior_audio[0, :audio_len],
+                    self.global_step,
+                    sample_rate=self.sampling_rate,
+                )
+
+            plt.close(image_mels)
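For reference, the generator objective assembled in training_step above is

\mathcal{L}_G = \lambda_{\text{mel}} \, \mathcal{L}_{\text{mel}} + \lambda_{\text{kl}} \, \mathcal{L}_{\text{kl}} + \mathcal{L}_{\text{adv}} + \mathcal{L}_{\text{fm}}

where L_mel is the masked L1 mel distance, L_kl the flow/prior KL term, L_adv the least-squares generator loss, and L_fm the feature-matching loss; vits_decoder.yaml sets λ_mel = 45 and λ_kl = 1 (the constructor default is 0.1). Both optimizers run under manual optimization with gradient-norm clipping at 1000.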

+ 67 - 0
fish_speech/models/vits_decoder/losses.py

@@ -0,0 +1,67 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def feature_loss(fmap_r: list[torch.Tensor], fmap_g: list[torch.Tensor]):
+    loss = 0
+    for dr, dg in zip(fmap_r, fmap_g):
+        dr = dr.float().detach()
+        dg = dg.float()
+        loss += torch.mean(torch.abs(dr - dg))
+
+    return loss * 2
+
+
+def discriminator_loss(
+    disc_real_outputs: list[torch.Tensor], disc_generated_outputs: list[torch.Tensor]
+):
+    loss = 0
+    r_losses = []
+    g_losses = []
+    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+        dr = dr.float()
+        dg = dg.float()
+        r_loss = torch.mean((1 - dr) ** 2)
+        g_loss = torch.mean(dg**2)
+        loss += r_loss + g_loss
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+
+    return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs: list[torch.Tensor]):
+    loss = 0
+    gen_losses = []
+    for dg in disc_outputs:
+        dg = dg.float()
+        l = torch.mean((1 - dg) ** 2)
+        gen_losses.append(l)
+        loss += l
+
+    return loss, gen_losses
+
+
+def kl_loss(
+    z_p: torch.Tensor,
+    logs_q: torch.Tensor,
+    m_p: torch.Tensor,
+    logs_p: torch.Tensor,
+    z_mask: torch.Tensor,
+):
+    """
+    z_p, logs_q: [b, h, t_t]
+    m_p, logs_p: [b, h, t_t]
+    """
+    z_p = z_p.float()
+    logs_q = logs_q.float()
+    m_p = m_p.float()
+    logs_p = logs_p.float()
+    z_mask = z_mask.float()
+
+    kl = logs_p - logs_q - 0.5
+    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
+    kl = torch.sum(kl * z_mask)
+    l = kl / torch.sum(z_mask)
+    return l
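For reference, kl_loss above computes the masked average of the standard VITS KL term between the flow-mapped posterior sample and the prior Gaussian (with σ = exp(logs)):

\mathcal{L}_{\text{kl}} = \frac{\sum \left[ \log\sigma_p - \log\sigma_q - \tfrac{1}{2} + \tfrac{(z_p - \mu_p)^2}{2\sigma_p^2} \right] \cdot z_{\text{mask}}}{\sum z_{\text{mask}}}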

+ 76 - 59
fish_speech/models/vits_decoder/modules/models.py

@@ -1,10 +1,6 @@
-import copy
-import math
-
 import torch
 from torch import nn
-from torch.cuda.amp import autocast
-from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
+from torch.nn import Conv1d, Conv2d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
 
@@ -26,6 +22,7 @@ class TextEncoder(nn.Module):
         kernel_size,
         p_dropout,
         latent_channels=192,
+        codebook_size=264,
     ):
         super().__init__()
         self.out_channels = out_channels
@@ -51,9 +48,7 @@ class TextEncoder(nn.Module):
         self.encoder_text = attentions.Encoder(
             hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
         )
-        self.text_embedding = nn.Embedding(
-            322, hidden_channels
-        )  # We only use 264, but to make the weight happy
+        self.text_embedding = nn.Embedding(codebook_size, hidden_channels)
 
         self.mrte = MRTE()
 
@@ -68,7 +63,7 @@ class TextEncoder(nn.Module):
 
         self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
 
-    def forward(self, y, y_lengths, text, text_lengths, ge, test=None):
+    def forward(self, y, y_lengths, text, text_lengths, ge):
         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
             y.dtype
         )
@@ -80,8 +75,6 @@ class TextEncoder(nn.Module):
         text_mask = torch.unsqueeze(
             commons.sequence_mask(text_lengths, text.size(1)), 1
         ).to(y.dtype)
-        if test == 1:
-            text[:, :] = 0
         text = self.text_embedding(text).transpose(1, 2)
         text = self.encoder_text(text * text_mask, text_mask)
         y = self.mrte(y, y_mask, text, text_mask, ge)
@@ -388,10 +381,8 @@ class DiscriminatorS(torch.nn.Module):
 
 
 class EnsembledDiscriminator(torch.nn.Module):
-    def __init__(self, use_spectral_norm=False):
+    def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
         super().__init__()
-        periods = [2, 3, 5, 7, 11]
-
         discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
         discs = discs + [
             DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
@@ -438,6 +429,7 @@ class SynthesizerTrn(nn.Module):
         upsample_initial_channel,
         upsample_kernel_sizes,
         gin_channels=0,
+        codebook_size=264,
     ):
         super().__init__()
 
@@ -466,6 +458,7 @@ class SynthesizerTrn(nn.Module):
             n_layers,
             kernel_size,
             p_dropout,
+            codebook_size=codebook_size,
         )
         self.dec = Generator(
             inter_channels,
@@ -498,69 +491,80 @@ class SynthesizerTrn(nn.Module):
         for param in self.vq.parameters():
             param.requires_grad = False
 
-    def forward(self, audio, audio_lengths, y, y_lengths, text, text_lengths):
-        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
-            y.dtype
-        )
-        ge = self.ref_enc(y * y_mask, y_mask)
-        quantized = self.vq(audio, audio_lengths, sr=32000)
-
-        quantized = F.interpolate(quantized, size=int(y.shape[-1]), mode="nearest")
+    def forward(
+        self, audio, audio_lengths, gt_specs, gt_spec_lengths, text, text_lengths
+    ):
+        y_mask = torch.unsqueeze(
+            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
+        ).to(gt_specs.dtype)
+        ge = self.ref_enc(gt_specs * y_mask, y_mask)
+        quantized = self.vq(audio, audio_lengths)
+        quantized = F.interpolate(quantized, size=gt_specs.size(-1), mode="nearest")
 
         x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, y_lengths, text, text_lengths, ge
+            quantized, gt_spec_lengths, text, text_lengths, ge
         )
-        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
+        z, m_q, logs_q, y_mask = self.enc_q(gt_specs, gt_spec_lengths, g=ge)
         z_p = self.flow(z, y_mask, g=ge)
 
         z_slice, ids_slice = commons.rand_slice_segments(
-            z, y_lengths, self.segment_size
+            z, gt_spec_lengths, self.segment_size
         )
         o = self.dec(z_slice, g=ge)
+
         return (
             o,
             ids_slice,
             y_mask,
-            y_mask,
             (z, z_p, m_p, logs_p, m_q, logs_q),
         )
 
+    @torch.no_grad()
     def infer(
         self,
         audio,
         audio_lengths,
-        y,
-        y_lengths,
+        gt_specs,
+        gt_spec_lengths,
         text,
         text_lengths,
-        test=None,
         noise_scale=0.5,
     ):
-        # y_lengths = audio_lengths // 640 # 640 is the hop size
-        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(
-            y.dtype
-        )
-        ge = self.ref_enc(y * y_mask, y_mask)
-
-        quantized = self.vq(audio, audio_lengths, sr=32000)
-        print(quantized.size())
-        quantized = F.interpolate(
-            quantized, size=int(audio.shape[-1] // 640), mode="nearest"
-        )
-        print(quantized.size())
+        y_mask = torch.unsqueeze(
+            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
+        ).to(gt_specs.dtype)
+        ge = self.ref_enc(gt_specs * y_mask, y_mask)
+        quantized = self.vq(audio, audio_lengths)
 
         x, m_p, logs_p, y_mask = self.enc_p(
-            quantized, audio_lengths, text, text_lengths, ge, test=test
+            quantized, audio_lengths, text, text_lengths, ge
         )
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
 
         z = self.flow(z_p, y_mask, g=ge, reverse=True)
 
-        o = self.dec((z * y_mask)[:, :, :], g=ge)
-        return o, y_mask, (z, z_p, m_p, logs_p)
+        o = self.dec(z * y_mask, g=ge)
+        return o
+
+    @torch.no_grad()
+    def infer_posterior(
+        self,
+        gt_specs,
+        gt_spec_lengths,
+    ):
+        y_mask = torch.unsqueeze(
+            commons.sequence_mask(gt_spec_lengths, gt_specs.size(2)), 1
+        ).to(gt_specs.dtype)
+        ge = self.ref_enc(gt_specs * y_mask, y_mask)
+        z, m_q, logs_q, y_mask = self.enc_q(gt_specs, gt_spec_lengths, g=ge)
+        o = self.dec(z * y_mask, g=ge)
+
+        return o
 
     @torch.no_grad()
     def decode(self, codes, text, refer, noise_scale=0.5):
+        # TODO: not tested yet
+
         ge = None
         if refer is not None:
             refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
@@ -587,17 +591,12 @@ class SynthesizerTrn(nn.Module):
         o = self.dec((z * y_mask)[:, :, :], g=ge)
         return o
 
-    def extract_latent(self, x):
-        ssl = self.ssl_proj(x)
-        quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
-        return codes.transpose(0, 1)
-
 
 if __name__ == "__main__":
     import librosa
     from transformers import AutoTokenizer
 
-    from fish_speech.models.vqgan.spectrogram import LinearSpectrogram
+    from fish_speech.utils.spectrogram import LinearSpectrogram
 
     model = SynthesizerTrn(
         spec_channels=1025,
@@ -612,28 +611,46 @@ if __name__ == "__main__":
         resblock="1",
         resblock="1",
         resblock_kernel_sizes=[3, 7, 11],
         resblock_kernel_sizes=[3, 7, 11],
         resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
         resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-        upsample_rates=[10, 8, 2, 2, 2],
+        upsample_rates=[8, 8, 2, 2, 2],
         upsample_initial_channel=512,
         upsample_initial_channel=512,
         upsample_kernel_sizes=[16, 16, 8, 2, 2],
         upsample_kernel_sizes=[16, 16, 8, 2, 2],
         gin_channels=512,
         gin_channels=512,
     )
     )
 
 
-    ckpt = "./checkpoints/s2_big2k1_158000.pth"
+    ckpt = "checkpoints/Bert-VITS2/G_0.pth"
     # Try to load the model
     # Try to load the model
     print(f"Loading model from {ckpt}")
     print(f"Loading model from {ckpt}")
-    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=True)
+    checkpoint = torch.load(ckpt, map_location="cpu", weights_only=True)["model"]
+    d_checkpoint = torch.load(
+        "checkpoints/Bert-VITS2/D_0.pth", map_location="cpu", weights_only=True
+    )["model"]
+    print(checkpoint.keys())
+
+    checkpoint.pop("dec.cond.weight")
+    checkpoint.pop("enc_q.enc.cond_layer.weight_v")
+
+    new_checkpoint = {}
+    for k, v in checkpoint.items():
+        new_checkpoint["generator." + k] = v
+
+    for k, v in d_checkpoint.items():
+        new_checkpoint["discriminator." + k] = v
+
+    torch.save(new_checkpoint, "checkpoints/Bert-VITS2/ensemble.pth")
+    exit()
+
     print(model.load_state_dict(checkpoint, strict=False))
 
     # Test
 
-    ref_audio = librosa.load(
-        "data/source/Genshin/Chinese/五郎/vo_DQAQ010_15_gorou_07.wav", sr=32000
-    )[0]
+    ref_audio = librosa.load("data/source/云天河/云天河-旁白/《薄太太》第0025集-yth_24.wav", sr=32000)[
+        0
+    ]
     input_audio = librosa.load(
-        "data/source/Genshin/Chinese/空/vo_FDAQ003_46_hero_02.wav", sr=32000
+        "data/source/云天河/云天河-旁白/《薄太太》第0025集-yth_24.wav", sr=32000
     )[0]
-    # ref_audio = input_audio
-    text = "(现在看来花瓶里的水并不是用来隐藏水迹,而是在莉莉安与考威尔的争斗中不小心撞破的…)"
+    ref_audio = input_audio
+    text = "博兴只知道身边的小女人没睡着,他又凑到她耳边压低了声线。阮苏眉睁眼,不觉得你老公像英雄吗?阮苏还是没反应,这男人是不是有病?刚才那冰冷又强势的样子,和现在这幼稚无赖的样子,根本就判若二人。"
     encoded_text = AutoTokenizer.from_pretrained("fishaudio/fish-speech-1")
     spec = LinearSpectrogram(n_fft=2048, hop_length=640, win_length=2048)
 

+ 0 - 2
fish_speech/models/vits_decoder/modules/mrte.py

@@ -11,9 +11,7 @@ class MRTE(nn.Module):
         content_enc_channels=192,
         hidden_size=512,
         out_channels=192,
-        kernel_size=5,
         n_heads=4,
-        ge_layer=2,
     ):
         super(MRTE, self).__init__()
         self.cross_attention = MultiHeadAttention(hidden_size, hidden_size, n_heads)

+ 1 - 1
fish_speech/models/vits_decoder/modules/vq_encoder.py

@@ -3,7 +3,7 @@ from torch import nn
 
 from fish_speech.models.vqgan.modules.fsq import DownsampleFiniteScalarQuantize
 from fish_speech.models.vqgan.modules.wavenet import WaveNet
-from fish_speech.models.vqgan.spectrogram import LogMelSpectrogram
+from fish_speech.utils.spectrogram import LogMelSpectrogram
 
 
 class VQEncoder(nn.Module):

+ 0 - 0
fish_speech/models/vqgan/spectrogram.py → fish_speech/utils/spectrogram.py