Lengyue 2 года назад
Родитель
Commit
78c9747b4c

+ 100 - 0
fish_speech/configs/hubert_vq.yaml

@@ -0,0 +1,100 @@
+defaults:
+  - base
+  - _self_
+
+project: hubert_vq
+
+# Lightning Trainer
+trainer:
+  accumulate_grad_batches: 2
+  gradient_clip_val: 1000.0  # For safety
+  gradient_clip_algorithm: 'norm'
+  precision: 32
+  max_steps: 1_000_000
+
+# Tokenizer Configuration
+tokenizer:
+  _target_: transformers.AutoTokenizer.from_pretrained
+  pretrained_model_name_or_path: fishaudio/speech-lm-300m
+  revision: text-pretrain-10k
+
+# Dataset Configuration
+train_dataset:
+  - _target_: fish_speech.datasets.text.TextDataset
+    repo: fishaudio/cn-hubert-25hz-vq
+    prefix: 'data/train'
+
+val_dataset:
+  _target_: fish_speech.datasets.text.TextDataset
+  repo: fishaudio/cn-hubert-25hz-vq
+  prefix: 'data/test'
+
+data:
+  _target_: fish_speech.datasets.text.TextDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
+  num_workers: 4
+  batch_size: 8
+  tokenizer: ${tokenizer}
+
+# Model Configuration
+model:
+  _target_: fish_speech.models.vqgan.VQGAN
+
+  encoder:
+    _target_: fish_speech.models.modules.VQEncoder
+    in_channels: 1024
+    channels: 192
+    num_heads: 2
+    num_feature_layers: 2
+    num_speaker_layers: 4
+    num_mixin_layers: 4
+    input_downsample: true
+    code_book_size: 2048
+    freeze_vq: false
+
+  generator:
+    _target_: fish_speech.models.modules.Generator
+    initial_channel: 192
+    resblock: "1"
+    resblock_kernel_sizes: [3, 7, 11]
+    resblock_dilation_sizes: 
+      - [1, 3, 5]
+      - [1, 3, 5]
+      - [1, 3, 5]
+    upsample_rates: [10, 8, 2, 2, 2]
+    upsample_initial_channel: 512
+    upsample_kernel_sizes: [16, 16, 8, 2, 2]
+
+  discriminator:
+    _target_: fish_speech.models.modules.EnsembleDiscriminator
+
+  mel_transform:
+    _target_: fish_speech.models.spectrogram.LogMelSpectrogram
+    sample_rate: 32000
+    n_fft: 2048
+    hop_length: 640
+    win_length: 2048
+    n_mels: 128
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    betas: [0.8, 0.99]
+    eps: 1e-5
+
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
+    _partial_: true
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 2000
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.05
+  
+  # Restore from old checkpoint
+  generator_ckpt: results/hubert-vq-pretrain/rcell/G_23000.pth
+  discriminator_ckpt: results/hubert-vq-pretrain/rcell/D_23000.pth
+  kmeans_ckpt: results/hubert-vq-pretrain/rcell/kmeans_23000.pth

+ 0 - 89
fish_speech/configs/whisper_vq.yaml

@@ -1,89 +0,0 @@
-paths:
-  run_dir: results/whisper-vq
-  checkpoint_dir: ${paths.run_dir}/checkpoints
-
-hydra:
-  run:
-    dir: ${paths.run_dir}
-
-trainer:
-  _target_: lightning.fabric.Fabric
-  accelerator: gpu
-  strategy: 
-    _target_: lightning.fabric.strategies.DDPStrategy
-    static_graph: true
-
-  devices: auto
-  precision: bf16-mixed
-  loggers:
-    _target_: pytorch_lightning.loggers.TensorBoardLogger
-    save_dir: ${paths.run_dir}
-    name: tensorboard
-    version: null
-
-model:
-  _target_: fish_speech.models.whisper_vq.WhisperVQ
-  model_name_or_path: "openai/whisper-medium"
-
-  # Quantization
-  codebook_dim: 32
-  codebook_size: 4096
-  codebook_decay: 0.9
-  threshold_ema_dead_code: 0
-  use_cosine_similarity: true
-  downsample: true
-
-  # Attention
-  post_attention_depth: 2
-
-schedule:
-  batch_size: 64
-  micro_batch_size: 32
-  max_steps: 10000
-  save_interval: 2000
-  gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
-  clip_grad_norm: 2.0
-  log_interval: 50
-  eval_interval: 2000
-
-train_dataloader:
-  _target_: torch.utils.data.DataLoader
-  dataset:
-    _target_: fish_speech.datasets.whisper_vq.WhisperVQDataset
-    filelist: filelists/whisper-vq.train.filelist
-  batch_size: ${schedule.micro_batch_size}
-  num_workers: 16
-  prefetch_factor: 4
-  pin_memory: true
-  persistent_workers: true
-  shuffle: true
-  collate_fn:
-    _target_: fish_speech.datasets.whisper_vq.WhisperVQCollator
-
-valid_dataloader:
-  _target_: torch.utils.data.DataLoader
-  dataset:
-    _target_: fish_speech.datasets.whisper_vq.WhisperVQDataset
-    filelist: filelists/whisper-vq.test.filelist
-  batch_size: 16
-  num_workers: 8
-  prefetch_factor: 4
-  pin_memory: true
-  shuffle: false
-  collate_fn:
-    _target_: fish_speech.datasets.whisper_vq.WhisperVQCollator
-
-optimizer:
-  _target_: torch.optim.AdamW
-  lr: 3e-4
-  weight_decay: 0.1
-  betas: [0.9, 0.95]
-  eps: 1e-5
-
-scheduler:
-  _target_: torch.optim.lr_scheduler.LambdaLR
-  lr_lambda:
-    _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
-    _partial_: true
-    num_warmup_steps: 1000
-    num_training_steps: ${schedule.max_steps}

+ 8 - 35
fish_speech/datasets/whisper_vq.py → fish_speech/datasets/vqgan.py

@@ -5,53 +5,26 @@ import librosa
 import torch
 from torch.utils.data import Dataset
 from transformers import WhisperProcessor
-from whisper.audio import HOP_LENGTH, load_audio, log_mel_spectrogram, pad_or_trim
 
 
-class WhisperVQDataset(Dataset):
+class VQGANDataset(Dataset):
     def __init__(
-        self, filelist: str, model_name_or_path: str = "openai/whisper-medium"
+        self,
+        filelist: str,
+        sample_rate: int = 32000,
     ):
         super().__init__()
 
-        self.files = [
-            Path(line.strip()) for line in Path(filelist).read_text().splitlines()
-        ]
-        self.processor = WhisperProcessor.from_pretrained(model_name_or_path)
+        filelist = Path(filelist)
+        root = filelist.parent
+
+        self.files = [root / line.strip() for line in filelist.read_text().splitlines()]
 
     def __len__(self):
         return len(self.files)
 
     def __getitem__(self, idx):
         file = self.files[idx]
-        wav = load_audio(file)
-        wav_length = wav.shape[-1]
-        mel_length = wav_length // HOP_LENGTH + 1
-
-        wav = pad_or_trim(wav)
-        wav = torch.from_numpy(wav).float()
-        input_features = log_mel_spectrogram(wav)
-        mel_mask = torch.zeros(input_features.shape[1], dtype=torch.float)
-        mel_mask[:mel_length] = 1
-
-        input_ids = file.with_suffix(".whisper.txt").read_text().strip().split("\t")[0]
-        input_ids = [int(x) for x in input_ids.split(",")]
-
-        while input_ids[-1] in [
-            self.processor.tokenizer.pad_token_id,
-            self.processor.tokenizer.eos_token_id,
-        ]:
-            input_ids.pop()
-
-        input_ids.append(self.processor.tokenizer.eos_token_id)
-        input_ids = torch.tensor(input_ids, dtype=torch.long)
-
-        return {
-            "input_values": wav,
-            "input_features": input_features,
-            "input_ids": input_ids,
-            "mel_mask": mel_mask,
-        }
 
 
 @dataclass

+ 0 - 163
fish_speech/models/hubert_vq/utils.py

@@ -1,163 +0,0 @@
-import torch
-import torch.utils.data
-from librosa.filters import mel as librosa_mel_fn
-
-
-def convert_pad_shape(pad_shape):
-    l = pad_shape[::-1]
-    pad_shape = [item for sublist in l for item in sublist]
-    return pad_shape
-
-
-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-    """
-    PARAMS
-    ------
-    C: compression factor
-    """
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-
-
-def dynamic_range_decompression_torch(x, C=1):
-    """
-    PARAMS
-    ------
-    C: compression factor used to compress
-    """
-    return torch.exp(x) / C
-
-
-def spectral_normalize_torch(magnitudes):
-    output = dynamic_range_compression_torch(magnitudes)
-    return output
-
-
-def spectral_de_normalize_torch(magnitudes):
-    output = dynamic_range_decompression_torch(magnitudes)
-    return output
-
-
-mel_basis = {}
-hann_window = {}
-
-
-def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
-    if torch.min(y) < -1.0:
-        print("min value is ", torch.min(y))
-    if torch.max(y) > 1.0:
-        print("max value is ", torch.max(y))
-
-    global hann_window
-    dtype_device = str(y.dtype) + "_" + str(y.device)
-    wnsize_dtype_device = str(win_size) + "_" + dtype_device
-    if wnsize_dtype_device not in hann_window:
-        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
-            dtype=y.dtype, device=y.device
-        )
-
-    y = torch.nn.functional.pad(
-        y.unsqueeze(1),
-        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
-        mode="reflect",
-    )
-    y = y.squeeze(1)
-    spec = torch.stft(
-        y,
-        n_fft,
-        hop_length=hop_size,
-        win_length=win_size,
-        window=hann_window[wnsize_dtype_device],
-        center=center,
-        pad_mode="reflect",
-        normalized=False,
-        onesided=True,
-        return_complex=False,
-    )
-
-    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
-    return spec
-
-
-def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
-    global mel_basis
-    dtype_device = str(spec.dtype) + "_" + str(spec.device)
-    fmax_dtype_device = str(fmax) + "_" + dtype_device
-    if fmax_dtype_device not in mel_basis:
-        mel = librosa_mel_fn(
-            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
-        )
-        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
-            dtype=spec.dtype, device=spec.device
-        )
-    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
-    spec = spectral_normalize_torch(spec)
-    return spec
-
-
-def mel_spectrogram_torch(
-    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
-):
-    if torch.min(y) < -1.0:
-        print("min value is ", torch.min(y))
-    if torch.max(y) > 1.0:
-        print("max value is ", torch.max(y))
-
-    global mel_basis, hann_window
-    dtype_device = str(y.dtype) + "_" + str(y.device)
-    fmax_dtype_device = str(fmax) + "_" + dtype_device
-    wnsize_dtype_device = str(win_size) + "_" + dtype_device
-    if fmax_dtype_device not in mel_basis:
-        mel = librosa_mel_fn(
-            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
-        )
-        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
-            dtype=y.dtype, device=y.device
-        )
-    if wnsize_dtype_device not in hann_window:
-        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
-            dtype=y.dtype, device=y.device
-        )
-
-    y = torch.nn.functional.pad(
-        y.unsqueeze(1),
-        (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
-        mode="reflect",
-    )
-    y = y.squeeze(1)
-
-    spec = torch.stft(
-        y,
-        n_fft,
-        hop_length=hop_size,
-        win_length=win_size,
-        window=hann_window[wnsize_dtype_device],
-        center=center,
-        pad_mode="reflect",
-        normalized=False,
-        onesided=True,
-        return_complex=False,
-    )
-
-    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
-
-    spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
-    spec = spectral_normalize_torch(spec)
-
-    return spec

+ 3 - 0
fish_speech/models/vqgan/__init__.py

@@ -0,0 +1,3 @@
+from .lit_module import VQGAN
+
+__all__ = ["VQGAN"]

+ 11 - 53
fish_speech/models/hubert_vq/lit_module.py → fish_speech/models/vqgan/lit_module.py

@@ -1,70 +1,41 @@
 from typing import Any, Callable
 
+import lightning as L
 import torch
 import torch.nn.functional as F
-from fish_vocoder.models.vocoder import VocoderModel
-from fish_vocoder.modules.losses.stft import MultiResolutionSTFTLoss
-from fish_vocoder.utils.grad_norm import grad_norm
-from fish_vocoder.utils.mask import sequence_mask
 from torch import nn
 from torch.utils.checkpoint import checkpoint as gradient_checkpointing
 
 
-class GANModel(VocoderModel):
+class VQGAN(L.LightningModule):
     def __init__(
         self,
-        sampling_rate: int,
-        n_fft: int,
-        hop_length: int,
-        win_length: int,
-        num_mels: int,
         optimizer: Callable,
         lr_scheduler: Callable,
-        mel_transforms: nn.ModuleDict,
+        encoder: nn.Module,
         generator: nn.Module,
-        discriminators: nn.ModuleDict,
-        multi_resolution_stft_loss: MultiResolutionSTFTLoss,
-        num_frames: int,
-        crop_length: int | None = None,
-        checkpointing: bool = False,
-        feature_matching: bool = False,
+        discriminator: nn.Module,
+        mel_transform: nn.Module,
+        segment_size: int = 20480,
     ):
-        super().__init__(
-            sampling_rate=sampling_rate,
-            n_fft=n_fft,
-            hop_length=hop_length,
-            win_length=win_length,
-            num_mels=num_mels,
-        )
+        super().__init__()
 
         # Model parameters
         self.optimizer_builder = optimizer
         self.lr_scheduler_builder = lr_scheduler
 
-        # Spectrogram transforms
-        self.mel_transforms = mel_transforms
-
         # Generator and discriminators
         # Compile generator so that snake can save memory
         self.generator = generator
-        self.discriminators = discriminators
-
-        # Loss
-        self.multi_resolution_stft_loss = multi_resolution_stft_loss
+        self.discriminator = discriminator
+        self.mel_transform = mel_transform
 
         # Crop length for saving memory
-        self.num_frames = num_frames
-        self.crop_length = crop_length
+        self.segment_size = segment_size
 
         # Disable automatic optimization
         self.automatic_optimization = False
 
-        # Gradient checkpointing
-        self.checkpointing = checkpointing
-
-        # Feature matching
-        self.feature_matching = feature_matching
-
     def configure_optimizers(self):
         # Need two optimizers and two schedulers
         optimizer_generator = self.optimizer_builder(self.generator.parameters())
@@ -95,12 +66,7 @@ class GANModel(VocoderModel):
         )
 
     def training_generator(self, audio, audio_mask):
-        if self.training and self.checkpointing:
-            fake_audio, base_loss = gradient_checkpointing(
-                self.forward, audio, audio_mask, use_reentrant=False
-            )
-        else:
-            fake_audio, base_loss = self.forward(audio, audio_mask)
+        # fake_audio, base_loss = self.forward(audio, audio_mask)
 
         assert fake_audio.shape == audio.shape
 
@@ -308,14 +274,6 @@ class GANModel(VocoderModel):
         scheduler_g.step()
         scheduler_d.step()
 
-    def forward(self, audio, mask=None, input_spec=None):
-        if input_spec is None:
-            input_spec = self.mel_transforms.input(audio.squeeze(1))
-
-        fake_audio = self.generator(input_spec)
-
-        return fake_audio, 0
-
     def validation_step(self, batch: Any, batch_idx: int):
         audio, lengths = batch["audio"], batch["lengths"]
         audio_mask = sequence_mask(lengths)[:, None, :].to(audio.device, torch.float32)

+ 2 - 12
fish_speech/models/hubert_vq/modules.py → fish_speech/models/vqgan/modules.py

@@ -8,11 +8,7 @@ from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
 from vector_quantize_pytorch import VectorQuantize
 
-from fish_speech.models.hubert_vq.utils import (
-    convert_pad_shape,
-    get_padding,
-    init_weights,
-)
+from fish_speech.models.vqgan.utils import convert_pad_shape, get_padding, init_weights
 
 LRELU_SLOPE = 0.1
 
@@ -603,7 +599,6 @@ class Generator(nn.Module):
         upsample_rates,
         upsample_initial_channel,
         upsample_kernel_sizes,
-        gin_channels=0,
     ):
         super(Generator, self).__init__()
         self.num_kernels = len(resblock_kernel_sizes)
@@ -638,13 +633,8 @@ class Generator(nn.Module):
         self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
         self.ups.apply(init_weights)
 
-        if gin_channels != 0:
-            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
-
-    def forward(self, x, g=None):
+    def forward(self, x):
         x = self.conv_pre(x)
-        if g is not None:
-            x = x + self.cond(g)
 
         for i in range(self.num_upsamples):
             x = F.leaky_relu(x, LRELU_SLOPE)

+ 104 - 0
fish_speech/models/vqgan/spectrogram.py

@@ -0,0 +1,104 @@
+import torch
+from torch import Tensor, nn
+from torchaudio.transforms import MelScale
+
+
+class LinearSpectrogram(nn.Module):
+    def __init__(
+        self,
+        n_fft=2048,
+        win_length=2048,
+        hop_length=512,
+        center=False,
+        mode="pow2_sqrt",
+    ):
+        super().__init__()
+
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.hop_length = hop_length
+        self.center = center
+        self.mode = mode
+
+        self.register_buffer("window", torch.hann_window(win_length))
+
+    def forward(self, y: Tensor) -> Tensor:
+        if y.ndim == 3:
+            y = y.squeeze(1)
+
+        y = torch.nn.functional.pad(
+            y.unsqueeze(1),
+            (
+                (self.win_length - self.hop_length) // 2,
+                (self.win_length - self.hop_length + 1) // 2,
+            ),
+            mode="reflect",
+        ).squeeze(1)
+
+        spec = torch.stft(
+            y,
+            self.n_fft,
+            hop_length=self.hop_length,
+            win_length=self.win_length,
+            window=self.window,
+            center=self.center,
+            pad_mode="reflect",
+            normalized=False,
+            onesided=True,
+            return_complex=True,
+        )
+
+        spec = torch.view_as_real(spec)
+
+        if self.mode == "pow2_sqrt":
+            spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+
+        return spec
+
+
+class LogMelSpectrogram(nn.Module):
+    def __init__(
+        self,
+        sample_rate=44100,
+        n_fft=2048,
+        win_length=2048,
+        hop_length=512,
+        n_mels=128,
+        center=False,
+        f_min=0.0,
+        f_max=None,
+    ):
+        super().__init__()
+
+        self.sample_rate = sample_rate
+        self.n_fft = n_fft
+        self.win_length = win_length
+        self.hop_length = hop_length
+        self.center = center
+        self.n_mels = n_mels
+        self.f_min = f_min
+        self.f_max = f_max or sample_rate // 2
+
+        self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
+        self.mel_scale = MelScale(
+            self.n_mels,
+            self.sample_rate,
+            self.f_min,
+            self.f_max,
+            self.n_fft // 2 + 1,
+            "slaney",
+            "slaney",
+        )
+
+    def compress(self, x: Tensor) -> Tensor:
+        return torch.log(torch.clamp(x, min=1e-5))
+
+    def decompress(self, x: Tensor) -> Tensor:
+        return torch.exp(x)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.spectrogram(x)
+        x = self.mel_scale(x)
+        x = self.compress(x)
+
+        return x

+ 26 - 0
fish_speech/models/vqgan/utils.py

@@ -0,0 +1,26 @@
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)

+ 2 - 2
tools/calculate_hubert_features.py → tools/vqgan/calculate_hubert_features.py

@@ -145,8 +145,8 @@ def main(folder: str, num_workers: int):
     begin_time = time.time()
     processed_files = 0
 
-    for n_batch, idx in enumerate(range(0, len(files), 64)):
-        batch = files[idx : idx + 64]
+    for n_batch, idx in enumerate(range(0, len(files), 32)):
+        batch = files[idx : idx + 32]
         batch_time = process_batch(batch)
         total_time += batch_time
         processed_files += len(batch)