فهرست منبع

Rename project to fish-speech

Lengyue 2 سال پیش
والد
کامیت
6d0e669f8c

+ 1 - 1
.gitignore

@@ -1,6 +1,6 @@
 .pgx.*
 .pdm-python
-/speech_lm.egg-info
+/fish_speech.egg-info
 __pycache__
 /results
 /data

+ 9 - 9
speech_lm/configs/llama_finetune.yaml → fish_speech/configs/llama_finetune.yaml

@@ -43,15 +43,15 @@ schedule:
   clip_grad_norm: 1.0
 
 train_dataset:
-  _target_: speech_lm.datasets.cultura_x.InterleaveDataset
+  _target_: fish_speech.datasets.cultura_x.InterleaveDataset
   datasets:
-    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
       lang: 'en'
-    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
       lang: 'zh'
-    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
       lang: 'ja'
-    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
       repo: fishaudio/wenet-vq
       files:
         - data/train-00000-of-00018-b5a82c6054c6acca.parquet
@@ -81,21 +81,21 @@ train_dataloader:
   batch_size: ${schedule.micro_batch_size}
   num_workers: 8
   collate_fn:
-    _target_: speech_lm.datasets.cultura_x.CulutreXCollator
+    _target_: fish_speech.datasets.cultura_x.CulutreXCollator
     tokenizer: ${tokenizer}
     max_length: ${schedule.max_length}
 
 valid_dataloader:
   _target_: torch.utils.data.DataLoader
   dataset:
-    _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+    _target_: fish_speech.datasets.cultura_x.CulturaXDataset
     repo: fishaudio/wenet-vq
     files:
       - data/test-00000-of-00001-685250c116f5d321.parquet
   batch_size: ${schedule.micro_batch_size}
   num_workers: 1
   collate_fn:
-    _target_: speech_lm.datasets.cultura_x.CulutreXCollator
+    _target_: fish_speech.datasets.cultura_x.CulutreXCollator
     tokenizer: ${tokenizer}
     max_length: ${schedule.max_length}
 
@@ -109,7 +109,7 @@ optimizer:
 scheduler:
   _target_: torch.optim.lr_scheduler.LambdaLR
   lr_lambda:
-    _target_: speech_lm.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+    _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
     _partial_: true
     num_warmup_steps: 2000
     num_training_steps: ${schedule.max_steps}

+ 6 - 6
speech_lm/configs/llama_pretrain.yaml → fish_speech/configs/llama_pretrain.yaml

@@ -46,13 +46,13 @@ schedule:
   clip_grad_norm: 1.0
 
 dataset:
-  _target_: speech_lm.datasets.cultura_x.InterleaveDataset
+  _target_: fish_speech.datasets.cultura_x.InterleaveDataset
   datasets:
-    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
       lang: 'en'
-    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
       lang: 'zh'
-    - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
+    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
       lang: 'ja'
   probabilities: [0.4, 0.3, 0.3]
   seed: 42
@@ -63,7 +63,7 @@ train_dataloader:
   batch_size: ${schedule.micro_batch_size}
   num_workers: 8
   collate_fn:
-    _target_: speech_lm.datasets.cultura_x.CulutreXCollator
+    _target_: fish_speech.datasets.cultura_x.CulutreXCollator
     tokenizer: ${tokenizer}
     max_length: ${schedule.max_length}
 
@@ -77,7 +77,7 @@ optimizer:
 scheduler:
   _target_: torch.optim.lr_scheduler.LambdaLR
   lr_lambda:
-    _target_: speech_lm.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+    _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
     _partial_: true
     num_warmup_steps: 2000
     num_training_steps: ${schedule.max_steps}

+ 6 - 6
speech_lm/configs/whisper_vq.yaml → fish_speech/configs/whisper_vq.yaml

@@ -22,7 +22,7 @@ trainer:
     version: null
 
 model:
-  _target_: speech_lm.models.whisper_vq.WhisperVQ
+  _target_: fish_speech.models.whisper_vq.WhisperVQ
   model_name_or_path: "openai/whisper-medium"
 
   # Quantization
@@ -49,7 +49,7 @@ schedule:
 train_dataloader:
   _target_: torch.utils.data.DataLoader
   dataset:
-    _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
+    _target_: fish_speech.datasets.whisper_vq.WhisperVQDataset
     filelist: filelists/whisper-vq.train.filelist
   batch_size: ${schedule.micro_batch_size}
   num_workers: 16
@@ -58,12 +58,12 @@ train_dataloader:
   persistent_workers: true
   shuffle: true
   collate_fn:
-    _target_: speech_lm.datasets.whisper_vq.WhisperVQCollator
+    _target_: fish_speech.datasets.whisper_vq.WhisperVQCollator
 
 valid_dataloader:
   _target_: torch.utils.data.DataLoader
   dataset:
-    _target_: speech_lm.datasets.whisper_vq.WhisperVQDataset
+    _target_: fish_speech.datasets.whisper_vq.WhisperVQDataset
     filelist: filelists/whisper-vq.test.filelist
   batch_size: 16
   num_workers: 8
@@ -71,7 +71,7 @@ valid_dataloader:
   pin_memory: true
   shuffle: false
   collate_fn:
-    _target_: speech_lm.datasets.whisper_vq.WhisperVQCollator
+    _target_: fish_speech.datasets.whisper_vq.WhisperVQCollator
 
 optimizer:
   _target_: torch.optim.AdamW
@@ -83,7 +83,7 @@ optimizer:
 scheduler:
   _target_: torch.optim.lr_scheduler.LambdaLR
   lr_lambda:
-    _target_: speech_lm.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+    _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
     _partial_: true
     num_warmup_steps: 1000
     num_training_steps: ${schedule.max_steps}

+ 2 - 2
speech_lm/datasets/cultura_x.py → fish_speech/datasets/cultura_x.py

@@ -12,7 +12,7 @@ from torch.distributed import get_rank, get_world_size, is_initialized
 from torch.utils.data import IterableDataset, get_worker_info
 from transformers import AutoTokenizer
 
-from speech_lm.utils.braceexpand import braceexpand
+from fish_speech.utils.braceexpand import braceexpand
 
 SUBSETS = {
     "en": "en_part_{00000..03071}",
@@ -157,7 +157,7 @@ class InterleaveDataset(IterableDataset):
 if __name__ == "__main__":
     from torch.utils.data import DataLoader
 
-    from speech_lm.datasets.wenet_vq import WenetVQDataset
+    from fish_speech.datasets.wenet_vq import WenetVQDataset
 
     dataset_en = CulturaXDataset("en")
     dataset_ja = CulturaXDataset("ja")

+ 2 - 2
speech_lm/datasets/whisper_vq.py → fish_speech/datasets/whisper_vq.py

@@ -109,8 +109,8 @@ if __name__ == "__main__":
     from torch.utils.data import DataLoader
     from transformers import GenerationConfig
 
-    from speech_lm.models.whisper_vq import WhisperVQ
-    from speech_lm.modules.flash_whisper import FlashWhisperForConditionalGeneration
+    from fish_speech.models.whisper_vq import WhisperVQ
+    from fish_speech.modules.flash_whisper import FlashWhisperForConditionalGeneration
 
     dataset = WhisperVQDataset("filelists/whisper-vq.test.filelist")
     dataloader = DataLoader(

+ 0 - 0
speech_lm/logger.py → fish_speech/logger.py


+ 2 - 2
speech_lm/models/whisper_vq.py → fish_speech/models/whisper_vq.py

@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from vector_quantize_pytorch import VectorQuantize
 
-from speech_lm.modules.flash_whisper import (
+from fish_speech.modules.flash_whisper import (
     FlashWhisperEncoderLayer,
     FlashWhisperForConditionalGeneration,
 )
@@ -200,7 +200,7 @@ if __name__ == "__main__":
     from torch.utils.data import DataLoader
     from transformers import WhisperProcessor
 
-    from speech_lm.datasets.whisper_vq import WhisperVQCollator, WhisperVQDataset
+    from fish_speech.datasets.whisper_vq import WhisperVQCollator, WhisperVQDataset
 
     processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
     model = WhisperVQ()

+ 0 - 0
speech_lm/modules/flash_whisper.py → fish_speech/modules/flash_whisper.py


+ 0 - 0
speech_lm/scheduler.py → fish_speech/scheduler.py


+ 292 - 0
fish_speech/train copy.py

@@ -0,0 +1,292 @@
+import time
+from collections import defaultdict
+from datetime import timedelta
+from pathlib import Path
+from typing import Optional
+
+import hydra
+import torch
+from lightning.fabric import Fabric
+from natsort import natsorted
+from omegaconf import DictConfig, OmegaConf
+from tqdm import tqdm
+from transformers import LlamaForCausalLM
+from transformers.utils import is_flash_attn_available
+
+from fish_speech.logger import RankedLogger
+
+# Allow TF32 on Ampere GPUs
+torch.set_float32_matmul_precision("high")
+torch.backends.cudnn.allow_tf32 = True
+
+# register eval resolver
+OmegaConf.register_new_resolver("eval", eval)
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def valid(
+    model: LlamaForCausalLM,
+    valid_dataloader: Optional[torch.utils.data.DataLoader],
+    global_step: int,
+    fabric: Fabric,
+    cfg: DictConfig,
+):
+    model.eval()
+    log.info(f"Evaluating at step {global_step}")
+
+    accumulate_infos = None
+
+    for idx, batch in enumerate(tqdm(valid_dataloader, desc="Evaluating")):
+        outputs = model(**batch)
+        loss = outputs.loss
+        metrics = getattr(outputs, "metrics", {})
+        log_info = {
+            "valid/loss": float(loss),
+            **{f"valid/{k}": float(v) for k, v in metrics.items()},
+        }
+
+        fabric.log_dict(
+            log_info,
+            step=global_step + idx,
+        )
+
+        # Update log info
+        if accumulate_infos is None:
+            accumulate_infos = log_info
+        else:
+            assert set(accumulate_infos.keys()) == set(
+                log_info.keys()
+            ), "Log keys changed during evaluation"
+            for k in accumulate_infos.keys():
+                accumulate_infos[k] += log_info[k]
+
+        if idx == getattr(cfg.schedule, "eval_max_batches", None):
+            break
+
+    # Log average
+    items = []
+    for k in accumulate_infos.keys():
+        items.append(f"{k}: {accumulate_infos[k] / (idx + 1):.4f}")
+    log.info(f"Average: {' | '.join(items)}")
+
+
+def train(
+    model: LlamaForCausalLM,
+    optimizer: torch.optim.Optimizer,
+    scheduler: torch.optim.lr_scheduler._LRScheduler,
+    train_dataloader: torch.utils.data.DataLoader,
+    valid_dataloader: Optional[torch.utils.data.DataLoader],
+    global_step: int,
+    fabric: Fabric,
+    cfg: DictConfig,
+):
+    accumulate_steps = 0
+    optimizer.zero_grad()
+
+    # Start time is ~model forward time + data loading time
+    start_time = time.time()
+    trackers = defaultdict(list)
+
+    while global_step < cfg.schedule.max_steps:
+        last_batch_time = time.time()
+        for batch in train_dataloader:
+            # Measure time used by data loading
+            trackers["data_time"].append(time.time() - last_batch_time)
+
+            # Measure time used by model forward
+            model_begin_time = time.time()
+            model.train()
+
+            # Accumulate gradients
+            gradient_accumulation_steps = cfg.schedule.gradient_accumulation_steps
+            is_accumulating = accumulate_steps % gradient_accumulation_steps != 0
+            accumulate_steps += 1
+
+            # Train one step
+            with fabric.no_backward_sync(model, enabled=is_accumulating):
+                outputs = model(**batch)
+                loss = outputs.loss
+                metrics = getattr(outputs, "metrics", {})
+
+                # Need to divide loss by accumulation steps
+                fabric.backward(loss / gradient_accumulation_steps)
+
+                # Update trackers
+                trackers["loss"].append(float(loss))
+                trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
+                for k, v in metrics.items():
+                    trackers[f"metrics/{k}"].append(float(v))
+
+            trackers["model_time"].append(time.time() - model_begin_time)
+
+            if is_accumulating:
+                last_batch_time = time.time()
+                continue
+
+            # Check all trackers has the same length
+            assert (
+                len(set(len(v) for k, v in trackers.items() if k != "grad_norm")) == 1
+            ), "Trackers has ambiguous length"
+
+            # Perform gradient clipping
+            grad_norm = fabric.clip_gradients(
+                model,
+                optimizer,
+                max_norm=cfg.schedule.clip_grad_norm,
+                norm_type=2.0,
+                error_if_nonfinite=True,
+            )
+
+            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
+                log.warning(f"Gradient norm is {grad_norm}, skipping update")
+                optimizer.zero_grad()
+
+            # We can't average gradients across multiple steps
+            trackers["grad_norm"].append(float(grad_norm))
+
+            # Update
+            optimizer.step()
+            optimizer.zero_grad()
+            scheduler.step()
+
+            fabric.log_dict(
+                {
+                    f"train/{k}": sum(v[-gradient_accumulation_steps:])
+                    / len(v[-gradient_accumulation_steps:])
+                    for k, v in trackers.items()
+                },
+                step=global_step,
+            )
+
+            # accumulate_steps = 0
+            global_step += 1
+
+            if global_step % cfg.schedule.log_interval == 0:
+                step_time = (time.time() - start_time) / cfg.schedule.log_interval
+                eta = step_time * (cfg.schedule.max_steps - global_step)
+                additional_info = [
+                    f"{k}: {sum(v[-cfg.schedule.log_interval:]) / len(v[-cfg.schedule.log_interval:]):.4f}"
+                    for k, v in trackers.items()
+                    if k != "lr"  # lr use .2e format
+                ]
+
+                log.info(
+                    f"[{global_step}/{cfg.schedule.max_steps}] "
+                    + f"step_time: {step_time:.2f}s "
+                    + f"ETA: {timedelta(seconds=round(eta))}s "
+                    f"lr: {optimizer.param_groups[0]['lr']:.2e} "
+                    + " ".join(additional_info)
+                )
+
+                # Reset trackers
+                trackers = defaultdict(list)
+
+                start_time = time.time()
+
+            if global_step % cfg.schedule.save_interval == 0:
+                fabric.save(
+                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
+                    {
+                        "model": model,
+                        "optimizer": optimizer,
+                        "scheduler": scheduler.state_dict(),
+                        "global_step": global_step,
+                    },
+                )
+
+            if (
+                getattr(cfg.schedule, "eval_interval", None) is not None
+                and global_step % cfg.schedule.eval_interval == 0
+                and valid_dataloader is not None
+            ):
+                valid(model, valid_dataloader, global_step, fabric, cfg)
+
+            if global_step >= cfg.schedule.max_steps:
+                break
+
+            last_batch_time = time.time()
+
+
+@hydra.main(
+    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
+)
+def main(cfg: DictConfig):
+    log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")
+
+    if is_flash_attn_available() is False:
+        log.warning("Flash attention is not available, using default attention")
+
+    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
+    fabric.launch()
+    log.info(f"Fabric: {fabric}")
+
+    model = hydra.utils.instantiate(cfg.model)
+    log.info(f"Model: {repr(model)}")
+
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+    freeze_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
+    log.info(f"Trainable parameters: {trainable_params/1e6:.2f}M")
+    log.info(f"Freeze parameters: {freeze_params/1e6:.2f}M")
+
+    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
+    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
+    log.info(f"Optimizer: {optimizer}")
+    log.info(f"Scheduler: {scheduler}")
+
+    log.info(f"Setup fabric model & dataset")
+    model = fabric.setup_module(model)
+    optimizer = fabric.setup_optimizers(optimizer)
+
+    # Build state
+    global_step = 0
+
+    # Restore training from checkpoint
+    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    # Alphabetically sort checkpoints
+    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
+    if len(checkpoints) > 0:
+        checkpoint_path = checkpoints[-1]
+
+        log.info(f"Restoring checkpoint from {checkpoint_path}")
+        remainder = fabric.load(
+            checkpoint_path,
+            {
+                "model": model,
+                "optimizer": optimizer,
+                "scheduler": scheduler,
+            },
+        )
+        global_step = remainder["global_step"]
+        log.info(f"Restored global step: {global_step}")
+
+    train_dataloader = hydra.utils.instantiate(cfg.train_dataloader)
+    log.info(f"Train Dataloader: {train_dataloader}")
+
+    valid_dataloader = None
+    if getattr(cfg, "valid_dataloader", None) is not None:
+        valid_dataloader = hydra.utils.instantiate(cfg.valid_dataloader)
+        log.info(f"Valid Dataloader: {valid_dataloader}")
+
+    train_dataloader = fabric.setup_dataloaders(train_dataloader)
+    if valid_dataloader is not None:
+        valid_dataloader = fabric.setup_dataloaders(valid_dataloader)
+
+    log.info(f"Begin training")
+
+    train(
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        global_step=global_step,
+        fabric=fabric,
+        cfg=cfg,
+    )
+
+
+if __name__ == "__main__":
+    main()

+ 1 - 1
speech_lm/train.py → fish_speech/train.py

@@ -13,7 +13,7 @@ from tqdm import tqdm
 from transformers import LlamaForCausalLM
 from transformers.utils import is_flash_attn_available
 
-from speech_lm.logger import RankedLogger
+from fish_speech.logger import RankedLogger
 
 # Allow TF32 on Ampere GPUs
 torch.set_float32_matmul_precision("high")

+ 0 - 0
speech_lm/utils/braceexpand.py → fish_speech/utils/braceexpand.py


+ 3 - 3
setup.py

@@ -1,7 +1,7 @@
 from setuptools import find_packages, setup
 
 setup(
-    name="speech-lm",
-    version="0.0.1",
-    packages=find_packages(include=["speech_lm", "speech_lm.*"]),
+    name="fish-speech",
+    version="0.1.0",
+    packages=find_packages(include=["fish_speech", "fish_speech.*"]),
 )

+ 1 - 1
tools/whisper_asr.py

@@ -16,7 +16,7 @@ from loguru import logger
 from transformers import WhisperProcessor
 from whisper.audio import load_audio, log_mel_spectrogram, pad_or_trim
 
-from speech_lm.modules.flash_whisper import FlashWhisperForConditionalGeneration
+from fish_speech.modules.flash_whisper import FlashWhisperForConditionalGeneration
 
 RANK_STR = ""