Sfoglia il codice sorgente

Remove vqgan & better data server && more data

Lengyue 2 anni fa
parent
commit
2eece7e37a

+ 9 - 3
data_server/src/main.rs

@@ -27,6 +27,7 @@ pub struct MyDataService {
 fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>> {
     let mut text_data_list = Vec::new();
     let mut index = 0;
+    let mut total_vq_frames = 0;
 
     loop {
         let mut size_buf = [0u8; 4];
@@ -43,15 +44,22 @@ fn read_pb_stream<R: Read>(mut reader: BufReader<R>) -> io::Result<Vec<TextData>
 
         let text_data = TextData::decode(&message_buf[..])
             .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+
+        text_data.sentences.iter().for_each(|sentence| {
+            total_vq_frames += sentence.semantics[0].values.len();
+        });
+        
         text_data_list.push(text_data);
 
         index += 1;
 
         if index % 10000 == 0 {
-            info!("Loaded {} groups", index);
+            info!("Loaded {} groups, total vq frames: {}", index, total_vq_frames);
         }
     }
 
+    info!("Loaded {} groups, total vq frames: {}", index, total_vq_frames);
+
     Ok(text_data_list)
 }
 
@@ -71,8 +79,6 @@ impl MyDataService {
             }
         }
 
-        info!("Loaded {} groups", groups.len());
-
         Ok(MyDataService {
             groups,
             weights,

+ 8 - 0
fish_speech/configs/data/libri-light.yaml

@@ -0,0 +1,8 @@
+datasets:
+  - root: /***REMOVED***/workspace/eva-gan/data/libri-light
+    source: LibriLight
+    languages: [ZH, EN]
+    extension: .txt
+    # This controls how the files of the dataset are grouped (e.g. by speaker).
+    # A level of 1 means the parent folder of the file is used as the group name.
+    group_parent_level: [3]  # speaker

+ 8 - 0
fish_speech/configs/data/playerfm.yaml

@@ -0,0 +1,8 @@
+datasets:
+  - root: /***REMOVED***/workspace/eva-gan/data/playerfm/data/asr/00000/zh
+    source: PlayerFM-zh
+    languages: [ZH, EN]
+    extension: .txt
+    # This controls how the files of the dataset are grouped (e.g. by speaker).
+    # A level of 1 means the parent folder of the file is used as the group name.
+    group_parent_level: [2, 1]  # episode-speaker

+ 0 - 128
fish_speech/configs/vqgan_finetune.yaml

@@ -1,128 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vqgan_finetune
-ckpt_path: checkpoints/vqgan-v1.pth
-resume_weights_only: true
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  strategy: ddp_find_unused_parameters_true
-  precision: 32
-  max_steps: 100000
-  val_check_interval: 1000
-
-sample_rate: 22050
-hop_length: 256
-num_mels: 80
-n_fft: 1024
-win_length: 1024
-segment_size: 256
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/demo/vq_train_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-  slice_frames: ${segment_size}
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/demo/vq_val_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 32
-  val_batch_size: 4
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vqgan.VQGAN
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-  encoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
-    hidden_channels: 512
-    kernel_size: 3
-    dilation_rate: 2
-    n_layers: 20
-    in_channels: ${num_mels}
-
-  vq:
-    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
-    in_channels: 512
-    vq_channels: 512
-    codebook_size: 256
-    codebook_groups: 4
-    codebook_layers: 2
-    downsample: 4
-
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
-    hidden_channels: 512
-    kernel_size: 3
-    dilation_rate: 2
-    n_layers: 20
-    out_channels: ${num_mels}
-
-  generator:
-    _target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
-    hop_length: ${hop_length}
-    upsample_rates: [8, 8, 2, 2, 2]  # aka. strides
-    upsample_kernel_sizes: [16, 16, 4, 4, 4]
-    resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    num_mels: ${num_mels}
-    upsample_initial_channel: 512
-    use_template: true
-    pre_conv_kernel_size: 7
-    post_conv_kernel_size: 7
-    ckpt_path: checkpoints/hifi-gan-base-002000000.ckpt
-
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
-    hidden_channels: 256
-    kernel_size: 3
-    dilation_rate: 2
-    n_layers: 6
-    in_channels: ${num_mels}
-
-  mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 2e-5
-    betas: [0.8, 0.99]
-    eps: 1e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.ExponentialLR
-    _partial_: true
-    gamma: 0.999999  # Estimated based on the LibriTTS dataset
-
-callbacks:
-  grad_norm_monitor:
-    sub_module: 
-      - generator
-      - discriminator
-      - decoder
-
-  model_checkpoint:
-    every_n_train_steps: 1000

+ 0 - 138
fish_speech/configs/vqgan_pretrain.yaml

@@ -1,138 +0,0 @@
-defaults:
-  - base
-  - _self_
-
-project: vqgan_pretrain_v2_large_30
-
-# Lightning Trainer
-trainer:
-  accelerator: gpu
-  devices: auto
-  strategy: ddp_find_unused_parameters_true
-  precision: bf16-mixed
-  max_steps: 10_000_000
-  val_check_interval: 5000
-
-sample_rate: 44100
-hop_length: 512
-num_mels: 160
-n_fft: 2048
-win_length: 2048
-segment_size: 256
-
-# Dataset Configuration
-train_dataset:
-  _target_: fish_speech.datasets.vqgan.MixDatast
-  datasets:
-    high-quality-441:
-      prob: 0.5
-      dataset:
-        _target_: fish_speech.datasets.vqgan.VQGANDataset
-        filelist: data/vocoder_data_441/vq_train_filelist.txt
-        sample_rate: ${sample_rate}
-        hop_length: ${hop_length}
-        slice_frames: ${segment_size}
-    
-    common-voice:
-      prob: 0.5
-      dataset:
-        _target_: fish_speech.datasets.vqgan.VQGANDataset
-        filelist: data/cv-corpus-16.0-2023-12-06/vq_train_filelist.txt
-        sample_rate: ${sample_rate}
-        hop_length: ${hop_length}
-        slice_frames: ${segment_size}
-
-val_dataset:
-  _target_: fish_speech.datasets.vqgan.VQGANDataset
-  filelist: data/vocoder_data_441/vq_val_filelist.txt
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-data:
-  _target_: fish_speech.datasets.vqgan.VQGANDataModule
-  train_dataset: ${train_dataset}
-  val_dataset: ${val_dataset}
-  num_workers: 4
-  batch_size: 32
-  val_batch_size: 4
-
-# Model Configuration
-model:
-  _target_: fish_speech.models.vqgan.VQGAN
-  sample_rate: ${sample_rate}
-  hop_length: ${hop_length}
-
-  encoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
-    hidden_channels: 512
-    kernel_size: 3
-    dilation_rate: 2
-    n_layers: 20
-    in_channels: ${num_mels}
-
-  vq:
-    _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
-    in_channels: 512
-    vq_channels: 512
-    codebook_size: 256
-    codebook_groups: 4
-    codebook_layers: 2
-    downsample: 4
-
-  decoder:
-    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
-    hidden_channels: 512
-    kernel_size: 3
-    dilation_rate: 2
-    n_layers: 20
-    out_channels: ${num_mels}
-
-  generator:
-    _target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
-    hop_length: ${hop_length}
-    upsample_rates: [8, 8, 2, 2, 2]  # aka. strides
-    upsample_kernel_sizes: [16, 16, 4, 4, 4]
-    resblock_kernel_sizes: [3, 7, 11]
-    resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
-    num_mels: ${num_mels}
-    upsample_initial_channel: 512
-    use_template: true
-    pre_conv_kernel_size: 7
-    post_conv_kernel_size: 7
-    ckpt_path: checkpoints/hifi-gan-base-002000000.ckpt
-
-  discriminator:
-    _target_: fish_speech.models.vqgan.modules.modules.WaveNet
-    hidden_channels: 256
-    kernel_size: 3
-    dilation_rate: 2
-    n_layers: 6
-    in_channels: ${num_mels}
-
-  mel_transform:
-    _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
-    sample_rate: ${sample_rate}
-    n_fft: ${n_fft}
-    hop_length: ${hop_length}
-    win_length: ${win_length}
-    n_mels: ${num_mels}
-
-  optimizer:
-    _target_: torch.optim.AdamW
-    _partial_: true
-    lr: 2e-4
-    betas: [0.8, 0.99]
-    eps: 1e-5
-
-  lr_scheduler:
-    _target_: torch.optim.lr_scheduler.ExponentialLR
-    _partial_: true
-    gamma: 0.999999  # Estimated based on the LibriTTS dataset
-
-callbacks:
-  grad_norm_monitor:
-    sub_module: 
-      - encoder
-      - vq
-      - decoder
-      - discriminator

+ 0 - 169
fish_speech/datasets/vqgan.py

@@ -1,169 +0,0 @@
-from dataclasses import dataclass
-from pathlib import Path
-from typing import Optional
-
-import librosa
-import numpy as np
-import torch
-from lightning import LightningDataModule
-from torch.utils.data import DataLoader, Dataset, IterableDataset
-
-from fish_speech.utils import RankedLogger
-
-logger = RankedLogger(__name__, rank_zero_only=False)
-
-
-class VQGANDataset(Dataset):
-    def __init__(
-        self,
-        filelist: str,
-        sample_rate: int = 32000,
-        hop_length: int = 640,
-        slice_frames: Optional[int] = None,
-    ):
-        super().__init__()
-
-        filelist = Path(filelist)
-        root = filelist.parent
-
-        self.files = [
-            root / line.strip()
-            for line in filelist.read_text().splitlines()
-            if line.strip()
-        ]
-        self.sample_rate = sample_rate
-        self.hop_length = hop_length
-        self.slice_frames = slice_frames
-
-    def __len__(self):
-        return len(self.files)
-
-    def get_item(self, idx):
-        file = self.files[idx]
-
-        audio, _ = librosa.load(file, sr=self.sample_rate, mono=True)
-
-        # Slice audio and features
-        if (
-            self.slice_frames is not None
-            and audio.shape[0] > self.slice_frames * self.hop_length
-        ):
-            start = np.random.randint(
-                0, audio.shape[0] - self.slice_frames * self.hop_length
-            )
-            audio = audio[start : start + self.slice_frames * self.hop_length]
-
-        if len(audio) == 0:
-            return None
-
-        max_value = np.abs(audio).max()
-        if max_value > 1.0:
-            audio = audio / max_value
-
-        return {
-            "audio": torch.from_numpy(audio),
-        }
-
-    def __getitem__(self, idx):
-        try:
-            return self.get_item(idx)
-        except Exception as e:
-            logger.error(f"Error loading {self.files[idx]}: {e}")
-            return None
-
-
-class MixDatast(IterableDataset):
-    def __init__(self, datasets: dict[str, dict], seed: int = 42) -> None:
-        values = list(datasets.values())
-        probs = [v["prob"] for v in values]
-        self.datasets = [v["dataset"] for v in values]
-
-        total_probs = sum(probs)
-        self.probs = [p / total_probs for p in probs]
-        self.seed = seed
-
-    def __iter__(self):
-        rng = np.random.default_rng(self.seed)
-        dataset_iterators = [iter(dataset) for dataset in self.datasets]
-
-        while True:
-            # Randomly pick one dataset, weighted by the normalized probabilities
-            dataset_idx = rng.choice(len(self.datasets), p=self.probs)
-            dataset_iterator = dataset_iterators[dataset_idx]
-
-            try:
-                yield next(dataset_iterator)
-            except StopIteration:
-                # Exhausted, create a new iterator
-                dataset_iterators[dataset_idx] = iter(self.datasets[dataset_idx])
-                yield next(dataset_iterators[dataset_idx])
-
-
-@dataclass
-class VQGANCollator:
-    def __call__(self, batch):
-        batch = [x for x in batch if x is not None]
-
-        audio_lengths = torch.tensor([len(x["audio"]) for x in batch])
-        audio_maxlen = audio_lengths.max()
-
-        # Right-pad every clip to the length of the longest clip in the batch
-        audios = []
-        for x in batch:
-            audios.append(
-                torch.nn.functional.pad(x["audio"], (0, audio_maxlen - len(x["audio"])))
-            )
-
-        return {
-            "audios": torch.stack(audios),
-            "audio_lengths": audio_lengths,
-        }
-
-
-class VQGANDataModule(LightningDataModule):
-    def __init__(
-        self,
-        train_dataset: VQGANDataset,
-        val_dataset: VQGANDataset,
-        batch_size: int = 32,
-        num_workers: int = 4,
-        val_batch_size: Optional[int] = None,
-    ):
-        super().__init__()
-
-        self.train_dataset = train_dataset
-        self.val_dataset = val_dataset
-        self.batch_size = batch_size
-        self.val_batch_size = val_batch_size or batch_size
-        self.num_workers = num_workers
-
-    def train_dataloader(self):
-        return DataLoader(
-            self.train_dataset,
-            batch_size=self.batch_size,
-            collate_fn=VQGANCollator(),
-            num_workers=self.num_workers,
-            shuffle=not isinstance(self.train_dataset, IterableDataset),
-        )
-
-    def val_dataloader(self):
-        return DataLoader(
-            self.val_dataset,
-            batch_size=self.batch_size,
-            collate_fn=VQGANCollator(),
-            num_workers=self.num_workers,
-        )
-
-
-if __name__ == "__main__":
-    dataset = VQGANDataset("data/LibriTTS_R/vq_train_filelist.txt")
-    dataloader = DataLoader(
-        dataset, batch_size=4, shuffle=False, collate_fn=VQGANCollator()
-    )
-
-    for batch in dataloader:
-        print(batch["audios"].shape)
-        print(batch["features"].shape)
-        print(batch["audio_lengths"])
-        print(batch["feature_lengths"])
-        break

+ 5 - 3
tools/llama/build_dataset.py

@@ -1,5 +1,6 @@
+import glob
 import re
-from collections import defaultdict
+from collections import Counter, defaultdict
 from multiprocessing import Pool
 from pathlib import Path
 
@@ -32,10 +33,11 @@ def task_generator_yaml(config):
             parent_level = [parent_level]
 
         # Load the files
-        files = list_files(root, AUDIO_EXTENSIONS, recursive=True, sort=True)
+        files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
+        files = sorted(files)
 
         grouped_files = defaultdict(list)
-        for file in files:
+        for file in tqdm(files, desc=f"Grouping {root}"):
             all_parents = []
             pointer = file
             while pointer.parent.name: