Browse Source

Support PyTorch Lightning

Lengyue 2 years ago
parent
commit
f7f2c03282

+ 3 - 0
fish_speech/callbacks/__init__.py

@@ -0,0 +1,3 @@
+from .grad_norm import GradNormMonitor
+
+__all__ = ["GradNormMonitor"]

+ 86 - 0
fish_speech/callbacks/grad_norm.py

@@ -0,0 +1,86 @@
+from typing import Union
+
+import torch
+from lightning import LightningModule, Trainer
+from lightning.pytorch.callbacks import Callback
+from torch import Tensor
+from torch.utils._foreach_utils import (
+    _group_tensors_by_device_and_dtype,
+    _has_foreach_support,
+)
+
+
+@torch.no_grad()
+def grad_norm(
+    parameters: Union[Tensor, list[Tensor]],
+    norm_type: float = 2.0,
+) -> Tensor:
+    """
+    Returns the norm of the gradients of the given parameters.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        norm_type (float): type of the used p-norm.
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """  # noqa: E501
+
+    if isinstance(parameters, Tensor):
+        parameters = [parameters]
+
+    grads = [p.grad for p in parameters if p.grad is not None]
+    if len(grads) == 0:
+        return torch.tensor(0.0)
+
+    first_device = grads[0].device
+    grouped_grads: dict[
+        tuple[torch.device, torch.dtype], list[list[Tensor]]
+    ] = _group_tensors_by_device_and_dtype(
+        [[g.detach() for g in grads]]
+    )  # type: ignore[assignment]
+
+    norms = []
+    for (device, _), ([grads], _) in grouped_grads.items():
+        if _has_foreach_support(grads, device=device):
+            norms.extend(torch._foreach_norm(grads, norm_type))
+        else:
+            norms.extend([torch.norm(g, norm_type) for g in grads])
+
+    return torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
+
+class GradNormMonitor(Callback):
+    """
+    Callback that computes the gradient norm of the model parameters.
+    """
+
+    def __init__(self, norm_type: float = 2.0, logging_interval: str = "step") -> None:
+        """
+        Args:
+            norm_type (float): type of the used p-norm.
+            logging_interval (str): "step" or "epoch".
+        """
+        super().__init__()
+        self.norm_type = norm_type
+        self.logging_interval = logging_interval
+
+    def on_after_backward(self, trainer: Trainer, model: LightningModule) -> None:
+        """
+        Computes the gradient norm of the model parameters and logs it to the logger.
+
+        Args:
+            trainer (Trainer): The trainer object
+            model (LightningModule): The current LightningModule
+        """
+
+        grad_norm_val = grad_norm(model.parameters(), self.norm_type)
+
+        model_name = model.__class__.__name__.lower()
+
+        on_step = self.logging_interval == "step"
+        model.log(
+            f"train/{model_name}/grad_norm",
+            grad_norm_val,
+            on_step=on_step,
+            on_epoch=not on_step,
+        )
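
A minimal usage sketch for the new callback; the model and datamodule are placeholders, not part of this commit:

# Hedged sketch: wiring GradNormMonitor into a Trainer by hand.
import lightning as L
from fish_speech.callbacks import GradNormMonitor

trainer = L.Trainer(
    callbacks=[GradNormMonitor(norm_type=2.0, logging_interval="step")],
)
# trainer.fit(model, datamodule=dm)  # placeholder model/datamodule;
# logs train/<model>/grad_norm after every backward pass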

+ 74 - 0
fish_speech/configs/base.yaml

@@ -0,0 +1,74 @@
+# Base configuration for training a model
+paths:
+  run_dir: results/${project}
+  ckpt_dir: ${paths.run_dir}/checkpoints
+
+hydra:
+  run:
+    dir: ${paths.run_dir}
+
+# Lightning Trainer
+trainer:
+  _target_: lightning.pytorch.trainer.Trainer
+
+  default_root_dir: ${paths.run_dir}
+  accelerator: gpu
+  num_nodes: 1
+  devices: 8
+  strategy:
+    _target_: lightning.pytorch.strategies.DDPStrategy
+    static_graph: true
+  precision: bf16-mixed
+
+  # disable validation by epoch end
+  check_val_every_n_epoch: null
+  val_check_interval: 5000
+  max_steps: 100_000
+
+  # Use torch.backends.cudnn.benchmark to speed up training
+  benchmark: true
+
+# Callbacks
+callbacks:
+  model_checkpoint:
+    _target_: lightning.pytorch.callbacks.ModelCheckpoint
+    dirpath: ${paths.ckpt_dir}
+    filename: "step_{step:09d}"
+    save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt
+    save_top_k: 5 # save 5 latest checkpoints
+    monitor: step # use step to monitor checkpoints
+    mode: max # save the latest checkpoint with the highest global_step
+    every_n_epochs: null # don't save checkpoints by epoch end
+    every_n_train_steps: 5000 # save checkpoints every 5000 steps
+    auto_insert_metric_name: false
+
+  model_summary:
+    _target_: lightning.pytorch.callbacks.RichModelSummary
+    max_depth: 2 # the maximum depth of layer nesting that the summary will include
+
+  rich_progress_bar:
+    _target_: lightning.pytorch.callbacks.RichProgressBar
+
+  learning_rate_monitor:
+    _target_: lightning.pytorch.callbacks.LearningRateMonitor
+    logging_interval: step
+    log_momentum: false
+
+  grad_norm_monitor:
+    _target_: fish_speech.callbacks.GradNormMonitor
+    norm_type: 2
+    logging_interval: step
+
+# Logger
+logger:
+  tensorboard:
+    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
+    save_dir: "${paths.run_dir}/tensorboard/"
+    name: null
+    log_graph: false
+    default_hp_metric: true
+    prefix: ""
+
+# Loop
+train: true
+test: false
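
As a usage sketch, this base config is consumed through Hydra; the command below is illustrative (the config and override keys exist in this commit, the device count and project name are arbitrary):

python fish_speech/train.py --config-name llama_finetune trainer.devices=4 project=my_run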

+ 53 - 96
fish_speech/configs/llama_finetune.yaml

@@ -1,116 +1,73 @@
-paths:
-  run_dir: results/finetune
-  checkpoint_dir: ${paths.run_dir}/checkpoints
+defaults:
+  - base
+  - _self_
 
-hydra:
-  run:
-    dir: ${paths.run_dir}
+project: llama_finetune
 
+# Lightning Trainer
 trainer:
-  _target_: lightning.fabric.Fabric
-  accelerator: gpu
-  strategy:
-    _target_: lightning.fabric.strategies.DDPStrategy
-    static_graph: true
-  num_nodes: 1
-  devices: 8
-  precision: bf16-mixed
-  loggers:
-    _target_: pytorch_lightning.loggers.TensorBoardLogger
-    save_dir: ${paths.run_dir}
-    name: tensorboard
-    version: null
-
-model:
-  _target_: transformers.AutoModelForCausalLM.from_pretrained
-  pretrained_model_name_or_path: fishaudio/speech-lm-300m
-  revision: text-pretrain-10k
+  accumulate_grad_batches: 2
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
 
+# Tokenizer Configuration
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
   pretrained_model_name_or_path: fishaudio/speech-lm-300m
   revision: text-pretrain-10k
 
-# This is a 200 billion seen token schedule
-schedule:
-  max_length: 1024
-  batch_size: 16  # 128 * 4 = 512
-  micro_batch_size: 8
-  max_steps: 100000
-  save_interval: 5000
-  log_interval: 10
-  gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
-  clip_grad_norm: 1.0
-
+# Dataset Configuration
 train_dataset:
-  _target_: fish_speech.datasets.cultura_x.InterleaveDataset
+  _target_: fish_speech.datasets.text.InterleaveDataset
   datasets:
-    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
-      lang: 'en'
-    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
-      lang: 'zh'
-    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
-      lang: 'ja'
-    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
+    - _target_: fish_speech.datasets.text.TextDataset
+      prefix: 'en/'
+    - _target_: fish_speech.datasets.text.TextDataset
+      prefix: 'zh/'
+    - _target_: fish_speech.datasets.text.TextDataset
+      prefix: 'ja/'
+    - _target_: fish_speech.datasets.text.TextDataset
       repo: fishaudio/wenet-vq
-      files:
-        - data/train-00000-of-00018-b5a82c6054c6acca.parquet
-        - data/train-00001-of-00018-82467b3e0669c2be.parquet
-        - data/train-00002-of-00018-d50ed8c218a1f183.parquet
-        - data/train-00003-of-00018-15d666053eade100.parquet
-        - data/train-00004-of-00018-01868cb8408e012b.parquet
-        - data/train-00005-of-00018-e766a0b54b1fd08b.parquet
-        - data/train-00006-of-00018-c79fad54ea8a0b8d.parquet
-        - data/train-00007-of-00018-e4155011a7081a1d.parquet
-        - data/train-00008-of-00018-8ba319f5af359d15.parquet
-        - data/train-00009-of-00018-9c9e984a6565b2c3.parquet
-        - data/train-00010-of-00018-7af80a80e5aa1e54.parquet
-        - data/train-00011-of-00018-2ab91221787a84a3.parquet
-        - data/train-00012-of-00018-4d477812eea5d298.parquet
-        - data/train-00013-of-00018-faf87b68b1ab4a15.parquet
-        - data/train-00014-of-00018-7f6bbd9bcb4cbb55.parquet
-        - data/train-00015-of-00018-d630fe4a488b9f51.parquet
-        - data/train-00016-of-00018-969a4d5dc04d2764.parquet
-        - data/train-00017-of-00018-bbfd09175809d1fe.parquet
+      prefix: 'data/train'
   probabilities: [0.2, 0.2, 0.2, 0.4]
   seed: 42
 
-train_dataloader:
-  _target_: torch.utils.data.DataLoader
-  dataset: ${train_dataset}
-  batch_size: ${schedule.micro_batch_size}
-  num_workers: 8
-  collate_fn:
-    _target_: fish_speech.datasets.cultura_x.CulutreXCollator
-    tokenizer: ${tokenizer}
-    max_length: ${schedule.max_length}
+val_dataset:
+  _target_: fish_speech.datasets.text.TextDataset
+  repo: fishaudio/wenet-vq
+  prefix: 'data/test'
 
-valid_dataloader:
-  _target_: torch.utils.data.DataLoader
-  dataset:
-    _target_: fish_speech.datasets.cultura_x.CulturaXDataset
-    repo: fishaudio/wenet-vq
-    files:
-      - data/test-00000-of-00001-685250c116f5d321.parquet
-  batch_size: ${schedule.micro_batch_size}
+data:
+  _target_: fish_speech.datasets.text.TextDataModule
+  train_dataset: ${train_dataset}
+  val_dataset: ${val_dataset}
   num_workers: 1
-  collate_fn:
-    _target_: fish_speech.datasets.cultura_x.CulutreXCollator
-    tokenizer: ${tokenizer}
-    max_length: ${schedule.max_length}
+  batch_size: 8
+  tokenizer: ${tokenizer}
 
-optimizer:
-  _target_: torch.optim.AdamW
-  lr: 1e-4
-  weight_decay: 0.1
-  betas: [0.9, 0.95]
-  eps: 1e-5
+# Model Configuration
+model:
+  _target_: fish_speech.models.text2semantic.TextToSemantic
+
+  model:
+    _target_: transformers.AutoModelForCausalLM.from_pretrained
+    pretrained_model_name_or_path: fishaudio/speech-lm-300m
+    revision: text-pretrain-10k
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 1e-4
+    weight_decay: 0.1
+    betas: [0.9, 0.95]
+    eps: 1e-5
 
-scheduler:
-  _target_: torch.optim.lr_scheduler.LambdaLR
-  lr_lambda:
-    _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
     _partial_: true
-    num_warmup_steps: 2000
-    num_training_steps: ${schedule.max_steps}
-    final_lr_ratio: 0.1
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 2000
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.1
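
For reference, `_partial_: true` tells Hydra to return a callable instead of an instantiated object; the optimizer node above is roughly equivalent to this hand-written sketch (not code from this commit):

import functools
import torch

# Approximately what hydra.utils.instantiate produces for the optimizer node:
optimizer_builder = functools.partial(
    torch.optim.AdamW, lr=1e-4, weight_decay=0.1, betas=(0.9, 0.95), eps=1e-5
)
# TextToSemantic.configure_optimizers later calls optimizer_builder(self.parameters()).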

+ 85 - 29
fish_speech/datasets/cultura_x.py → fish_speech/datasets/text.py

@@ -1,51 +1,55 @@
 import random
 from dataclasses import dataclass
-from logging import getLogger
+from itertools import chain
 from random import Random
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np
-import pandas as pd
 import pyarrow.parquet as pq
 from datasets.download.streaming_download_manager import xopen
+from huggingface_hub import HfApi
+from lightning import LightningDataModule
 from torch.distributed import get_rank, get_world_size, is_initialized
-from torch.utils.data import IterableDataset, get_worker_info
+from torch.utils.data import DataLoader, IterableDataset, get_worker_info
 from transformers import AutoTokenizer
 
+from fish_speech.utils import RankedLogger
 from fish_speech.utils.braceexpand import braceexpand
 
-SUBSETS = {
-    "en": "en_part_{00000..03071}",
-    "zh": "zh_part_{00000..00319}",
-    "ja": "ja_part_{00000..00159}",
-}
+log = RankedLogger(__name__, rank_zero_only=True)
 
-log = getLogger(__name__)
 
-
-class CulturaXDataset(IterableDataset):
+class TextDataset(IterableDataset):
     def __init__(
         self,
-        lang: Optional[str] = None,
+        files: Optional[Union[list[str], str]] = None,
+        prefix: Optional[str] = None,
         seed: int = 42,
         parquet_batch_size: int = 10000,
         repo: str = "uonlp/CulturaX",
-        files: Optional[list[str]] = None,
     ):
         super().__init__()
 
-        self.lang = lang
         self.seed = seed
         self.parquet_batch_size = parquet_batch_size
         self.repo = repo
 
-        if self.lang is not None:
-            files = sorted(list(braceexpand(f"{lang}/{SUBSETS[lang]}.parquet")))
+        if files is None and prefix is None:
+            raise ValueError("Either files or prefix must be specified")
+
+        if prefix is not None:
+            files = HfApi().list_repo_files(repo, repo_type="dataset")
+            files = [f for f in files if f.startswith(prefix)]
+            log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
         else:
-            files = list(files)
+            if isinstance(files, str):
+                files = [files]
+
+            files = list(chain.from_iterable(map(braceexpand, files)))
+            log.info(f"Expanded {len(files)} files in {repo}")
 
         # Get sharded files
-        self.files = files
+        self.files = sorted(files)
         Random(seed).shuffle(self.files)
 
     def get_data_splits(self, files):
@@ -100,7 +104,7 @@ class CulturaXDataset(IterableDataset):
 
 
 @dataclass
-class CulutreXCollator:
+class TextDataCollator:
     tokenizer: AutoTokenizer
     max_length: int = 512
 
@@ -154,16 +158,68 @@ class InterleaveDataset(IterableDataset):
                 yield next(dataset_iterators[dataset_idx])
 
 
-if __name__ == "__main__":
-    from torch.utils.data import DataLoader
+class TextDataModule(LightningDataModule):
+    def __init__(
+        self,
+        train_dataset: Union[TextDataset, InterleaveDataset],
+        val_dataset: Optional[Union[TextDataset, InterleaveDataset]] = None,
+        batch_size: int = 32,
+        tokenizer: Optional[AutoTokenizer] = None,
+        max_length: int = 1024,
+        num_workers: int = 4,
+    ):
+        super().__init__()
+
+        self.train_dataset = train_dataset
+        self.val_dataset = val_dataset
+        self.batch_size = batch_size
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+        self.num_workers = num_workers
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+            num_workers=self.num_workers,
+        )
 
-    from fish_speech.datasets.wenet_vq import WenetVQDataset
+    def val_dataloader(self):
+        if self.val_dataset is None:
+            return None
+
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            collate_fn=TextDataCollator(self.tokenizer, self.max_length),
+            num_workers=self.num_workers,
+        )
 
-    dataset_en = CulturaXDataset("en")
-    dataset_ja = CulturaXDataset("ja")
-    dataset_wenet = WenetVQDataset()
-    dataset = InterleaveDataset([dataset_en, dataset_wenet], [0.5, 0.5])
-    collator = CulutreXCollator(AutoTokenizer.from_pretrained("gpt2"))
 
-    for batch in DataLoader(dataset, batch_size=4, collate_fn=collator, num_workers=4):
+if __name__ == "__main__":
+    dm = TextDataModule(
+        InterleaveDataset(
+            datasets=[
+                TextDataset(
+                    prefix="en/en_part_",
+                ),
+                TextDataset(
+                    prefix="zh/zh_part_",
+                ),
+                TextDataset(
+                    prefix="ja/ja_part_",
+                ),
+            ],
+            probabilities=[0.8, 0.1, 0.1],
+        ),
+        TextDataset(
+            files="ja/ja_part_{00000..00159}",
+        ),
+        batch_size=2,
+        tokenizer=AutoTokenizer.from_pretrained("bert-base-multilingual-cased"),
+    )
+
+    for batch in dm.train_dataloader():
         print(batch)
+        break
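
The `files` argument in the example above goes through brace expansion; assuming the vendored braceexpand matches the usual semantics, the pattern expands like this:

from fish_speech.utils.braceexpand import braceexpand

shards = list(braceexpand("ja/ja_part_{00000..00159}"))
print(len(shards), shards[0], shards[-1])
# 160 ja/ja_part_00000 ja/ja_part_00159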

+ 3 - 0
fish_speech/models/text2semantic/__init__.py

@@ -0,0 +1,3 @@
+from .lit_module import TextToSemantic
+
+__all__ = ["TextToSemantic"]

+ 64 - 0
fish_speech/models/text2semantic/lit_module.py

@@ -0,0 +1,64 @@
+from typing import Any
+
+import lightning as L
+from lightning.pytorch.utilities.types import OptimizerLRScheduler
+from transformers import LlamaForCausalLM
+
+
+class TextToSemantic(L.LightningModule):
+    def __init__(self, model: LlamaForCausalLM, optimizer: Any, lr_scheduler: Any):
+        super().__init__()
+
+        self.model = model
+        self.optimizer_builder = optimizer
+        self.lr_scheduler_builder = lr_scheduler
+
+    def forward(self, x):
+        return self.model(x)
+
+    def configure_optimizers(self) -> OptimizerLRScheduler:
+        optimizer = self.optimizer_builder(self.parameters())
+        lr_scheduler = self.lr_scheduler_builder(optimizer)
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                "interval": "step",
+            },
+        }
+
+    def _step(self, batch, batch_idx, stage: str):
+        result = self.model(**batch)
+        loss = result.loss
+        logits = result.logits
+
+        self.log(
+            f"{stage}/loss",
+            loss,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+        )
+
+        # Top-5 accuracy
+        _, indices = logits.topk(5, dim=-1)
+        correct = indices.eq(batch["labels"].unsqueeze(-1)).sum()
+        accuracy = correct / batch["labels"].numel()
+        self.log(
+            f"{stage}/accuracy",
+            accuracy,
+            on_step=True,
+            on_epoch=False,
+            prog_bar=True,
+            logger=True,
+        )
+
+        return loss
+
+    def training_step(self, batch, batch_idx):
+        return self._step(batch, batch_idx, "train")
+
+    def validation_step(self, batch, batch_idx):
+        return self._step(batch, batch_idx, "val")
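
A hedged sketch of building TextToSemantic without Hydra, mirroring llama_finetune.yaml; functools.partial stands in for `_partial_: true`:

import functools
import torch
from transformers import AutoModelForCausalLM
from fish_speech.models.text2semantic import TextToSemantic
from fish_speech.scheduler import get_cosine_schedule_with_warmup_lr_lambda

module = TextToSemantic(
    model=AutoModelForCausalLM.from_pretrained(
        "fishaudio/speech-lm-300m", revision="text-pretrain-10k"
    ),
    optimizer=functools.partial(torch.optim.AdamW, lr=1e-4, weight_decay=0.1),
    lr_scheduler=functools.partial(
        torch.optim.lr_scheduler.LambdaLR,
        lr_lambda=functools.partial(
            get_cosine_schedule_with_warmup_lr_lambda,
            num_warmup_steps=2000,
            num_training_steps=100_000,
            final_lr_ratio=0.1,
        ),
    ),
)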

+ 0 - 292
fish_speech/train copy.py

@@ -1,292 +0,0 @@
-import time
-from collections import defaultdict
-from datetime import timedelta
-from pathlib import Path
-from typing import Optional
-
-import hydra
-import torch
-from lightning.fabric import Fabric
-from natsort import natsorted
-from omegaconf import DictConfig, OmegaConf
-from tqdm import tqdm
-from transformers import LlamaForCausalLM
-from transformers.utils import is_flash_attn_available
-
-from fish_speech.logger import RankedLogger
-
-# Allow TF32 on Ampere GPUs
-torch.set_float32_matmul_precision("high")
-torch.backends.cudnn.allow_tf32 = True
-
-# register eval resolver
-OmegaConf.register_new_resolver("eval", eval)
-
-log = RankedLogger(__name__, rank_zero_only=True)
-
-
-def valid(
-    model: LlamaForCausalLM,
-    valid_dataloader: Optional[torch.utils.data.DataLoader],
-    global_step: int,
-    fabric: Fabric,
-    cfg: DictConfig,
-):
-    model.eval()
-    log.info(f"Evaluating at step {global_step}")
-
-    accumulate_infos = None
-
-    for idx, batch in enumerate(tqdm(valid_dataloader, desc="Evaluating")):
-        outputs = model(**batch)
-        loss = outputs.loss
-        metrics = getattr(outputs, "metrics", {})
-        log_info = {
-            "valid/loss": float(loss),
-            **{f"valid/{k}": float(v) for k, v in metrics.items()},
-        }
-
-        fabric.log_dict(
-            log_info,
-            step=global_step + idx,
-        )
-
-        # Update log info
-        if accumulate_infos is None:
-            accumulate_infos = log_info
-        else:
-            assert set(accumulate_infos.keys()) == set(
-                log_info.keys()
-            ), "Log keys changed during evaluation"
-            for k in accumulate_infos.keys():
-                accumulate_infos[k] += log_info[k]
-
-        if idx == getattr(cfg.schedule, "eval_max_batches", None):
-            break
-
-    # Log average
-    items = []
-    for k in accumulate_infos.keys():
-        items.append(f"{k}: {accumulate_infos[k] / (idx + 1):.4f}")
-    log.info(f"Average: {' | '.join(items)}")
-
-
-def train(
-    model: LlamaForCausalLM,
-    optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler._LRScheduler,
-    train_dataloader: torch.utils.data.DataLoader,
-    valid_dataloader: Optional[torch.utils.data.DataLoader],
-    global_step: int,
-    fabric: Fabric,
-    cfg: DictConfig,
-):
-    accumulate_steps = 0
-    optimizer.zero_grad()
-
-    # Start time is ~model forward time + data loading time
-    start_time = time.time()
-    trackers = defaultdict(list)
-
-    while global_step < cfg.schedule.max_steps:
-        last_batch_time = time.time()
-        for batch in train_dataloader:
-            # Measure time used by data loading
-            trackers["data_time"].append(time.time() - last_batch_time)
-
-            # Measure time used by model forward
-            model_begin_time = time.time()
-            model.train()
-
-            # Accumulate gradients
-            gradient_accumulation_steps = cfg.schedule.gradient_accumulation_steps
-            is_accumulating = accumulate_steps % gradient_accumulation_steps != 0
-            accumulate_steps += 1
-
-            # Train one step
-            with fabric.no_backward_sync(model, enabled=is_accumulating):
-                outputs = model(**batch)
-                loss = outputs.loss
-                metrics = getattr(outputs, "metrics", {})
-
-                # Need to divide loss by accumulation steps
-                fabric.backward(loss / gradient_accumulation_steps)
-
-                # Update trackers
-                trackers["loss"].append(float(loss))
-                trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
-                for k, v in metrics.items():
-                    trackers[f"metrics/{k}"].append(float(v))
-
-            trackers["model_time"].append(time.time() - model_begin_time)
-
-            if is_accumulating:
-                last_batch_time = time.time()
-                continue
-
-            # Check all trackers has the same length
-            assert (
-                len(set(len(v) for k, v in trackers.items() if k != "grad_norm")) == 1
-            ), "Trackers has ambiguous length"
-
-            # Perform gradient clipping
-            grad_norm = fabric.clip_gradients(
-                model,
-                optimizer,
-                max_norm=cfg.schedule.clip_grad_norm,
-                norm_type=2.0,
-                error_if_nonfinite=True,
-            )
-
-            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
-                log.warning(f"Gradient norm is {grad_norm}, skipping update")
-                optimizer.zero_grad()
-
-            # We can't average gradients across multiple steps
-            trackers["grad_norm"].append(float(grad_norm))
-
-            # Update
-            optimizer.step()
-            optimizer.zero_grad()
-            scheduler.step()
-
-            fabric.log_dict(
-                {
-                    f"train/{k}": sum(v[-gradient_accumulation_steps:])
-                    / len(v[-gradient_accumulation_steps:])
-                    for k, v in trackers.items()
-                },
-                step=global_step,
-            )
-
-            # accumulate_steps = 0
-            global_step += 1
-
-            if global_step % cfg.schedule.log_interval == 0:
-                step_time = (time.time() - start_time) / cfg.schedule.log_interval
-                eta = step_time * (cfg.schedule.max_steps - global_step)
-                additional_info = [
-                    f"{k}: {sum(v[-cfg.schedule.log_interval:]) / len(v[-cfg.schedule.log_interval:]):.4f}"
-                    for k, v in trackers.items()
-                    if k != "lr"  # lr use .2e format
-                ]
-
-                log.info(
-                    f"[{global_step}/{cfg.schedule.max_steps}] "
-                    + f"step_time: {step_time:.2f}s "
-                    + f"ETA: {timedelta(seconds=round(eta))}s "
-                    f"lr: {optimizer.param_groups[0]['lr']:.2e} "
-                    + " ".join(additional_info)
-                )
-
-                # Reset trackers
-                trackers = defaultdict(list)
-
-                start_time = time.time()
-
-            if global_step % cfg.schedule.save_interval == 0:
-                fabric.save(
-                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
-                    {
-                        "model": model,
-                        "optimizer": optimizer,
-                        "scheduler": scheduler.state_dict(),
-                        "global_step": global_step,
-                    },
-                )
-
-            if (
-                getattr(cfg.schedule, "eval_interval", None) is not None
-                and global_step % cfg.schedule.eval_interval == 0
-                and valid_dataloader is not None
-            ):
-                valid(model, valid_dataloader, global_step, fabric, cfg)
-
-            if global_step >= cfg.schedule.max_steps:
-                break
-
-            last_batch_time = time.time()
-
-
-@hydra.main(
-    version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
-)
-def main(cfg: DictConfig):
-    log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")
-
-    if is_flash_attn_available() is False:
-        log.warning("Flash attention is not available, using default attention")
-
-    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
-    fabric.launch()
-    log.info(f"Fabric: {fabric}")
-
-    model = hydra.utils.instantiate(cfg.model)
-    log.info(f"Model: {repr(model)}")
-
-    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    freeze_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
-    log.info(f"Trainable parameters: {trainable_params/1e6:.2f}M")
-    log.info(f"Freeze parameters: {freeze_params/1e6:.2f}M")
-
-    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
-    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
-    log.info(f"Optimizer: {optimizer}")
-    log.info(f"Scheduler: {scheduler}")
-
-    log.info(f"Setup fabric model & dataset")
-    model = fabric.setup_module(model)
-    optimizer = fabric.setup_optimizers(optimizer)
-
-    # Build state
-    global_step = 0
-
-    # Restore training from checkpoint
-    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
-    checkpoint_dir.mkdir(parents=True, exist_ok=True)
-
-    # Alphabetically sort checkpoints
-    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
-    if len(checkpoints) > 0:
-        checkpoint_path = checkpoints[-1]
-
-        log.info(f"Restoring checkpoint from {checkpoint_path}")
-        remainder = fabric.load(
-            checkpoint_path,
-            {
-                "model": model,
-                "optimizer": optimizer,
-                "scheduler": scheduler,
-            },
-        )
-        global_step = remainder["global_step"]
-        log.info(f"Restored global step: {global_step}")
-
-    train_dataloader = hydra.utils.instantiate(cfg.train_dataloader)
-    log.info(f"Train Dataloader: {train_dataloader}")
-
-    valid_dataloader = None
-    if getattr(cfg, "valid_dataloader", None) is not None:
-        valid_dataloader = hydra.utils.instantiate(cfg.valid_dataloader)
-        log.info(f"Valid Dataloader: {valid_dataloader}")
-
-    train_dataloader = fabric.setup_dataloaders(train_dataloader)
-    if valid_dataloader is not None:
-        valid_dataloader = fabric.setup_dataloaders(valid_dataloader)
-
-    log.info(f"Begin training")
-
-    train(
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
-        train_dataloader=train_dataloader,
-        valid_dataloader=valid_dataloader,
-        global_step=global_step,
-        fabric=fabric,
-        cfg=cfg,
-    )
-
-
-if __name__ == "__main__":
-    main()

+ 99 - 266
fish_speech/train.py

@@ -1,19 +1,13 @@
-import time
-from collections import defaultdict
-from datetime import timedelta
-from pathlib import Path
 from typing import Optional
 
 import hydra
+import lightning as L
 import torch
-from lightning.fabric import Fabric
-from natsort import natsorted
+from lightning import Callback, LightningDataModule, LightningModule, Trainer
+from lightning.pytorch.loggers import Logger
 from omegaconf import DictConfig, OmegaConf
-from tqdm import tqdm
-from transformers import LlamaForCausalLM
-from transformers.utils import is_flash_attn_available
 
-from fish_speech.logger import RankedLogger
+import fish_speech.utils as utils
 
 # Allow TF32 on Ampere GPUs
 torch.set_float32_matmul_precision("high")
@@ -22,270 +16,109 @@ torch.backends.cudnn.allow_tf32 = True
 # register eval resolver
 OmegaConf.register_new_resolver("eval", eval)
 
-log = RankedLogger(__name__, rank_zero_only=True)
-
-
-def valid(
-    model: LlamaForCausalLM,
-    valid_dataloader: Optional[torch.utils.data.DataLoader],
-    global_step: int,
-    fabric: Fabric,
-    cfg: DictConfig,
-):
-    model.eval()
-    log.info(f"Evaluating at step {global_step}")
-
-    accumulate_infos = None
-
-    for idx, batch in enumerate(tqdm(valid_dataloader, desc="Evaluating")):
-        outputs = model(**batch)
-        loss = outputs.loss
-        metrics = getattr(outputs, "metrics", {})
-        log_info = {
-            "valid/loss": float(loss),
-            **{f"valid/{k}": float(v) for k, v in metrics.items()},
-        }
-
-        fabric.log_dict(
-            log_info,
-            step=global_step + idx,
-        )
-
-        # Update log info
-        if accumulate_infos is None:
-            accumulate_infos = log_info
-        else:
-            assert set(accumulate_infos.keys()) == set(
-                log_info.keys()
-            ), "Log keys changed during evaluation"
-            for k in accumulate_infos.keys():
-                accumulate_infos[k] += log_info[k]
-
-        if idx == getattr(cfg.schedule, "eval_max_batches", None):
-            break
-
-    # Log average
-    items = []
-    for k in accumulate_infos.keys():
-        items.append(f"{k}: {accumulate_infos[k] / (idx + 1):.4f}")
-    log.info(f"Average: {' | '.join(items)}")
-
-
-def train(
-    model: LlamaForCausalLM,
-    optimizer: torch.optim.Optimizer,
-    scheduler: torch.optim.lr_scheduler._LRScheduler,
-    train_dataloader: torch.utils.data.DataLoader,
-    valid_dataloader: Optional[torch.utils.data.DataLoader],
-    global_step: int,
-    fabric: Fabric,
-    cfg: DictConfig,
-):
-    accumulate_steps = 0
-    optimizer.zero_grad()
-
-    # Start time is ~model forward time + data loading time
-    start_time = time.time()
-    trackers = defaultdict(list)
-
-    while global_step < cfg.schedule.max_steps:
-        last_batch_time = time.time()
-        for batch in train_dataloader:
-            # Measure time used by data loading
-            trackers["data_time"].append(time.time() - last_batch_time)
-
-            # Measure time used by model forward
-            model_begin_time = time.time()
-            model.train()
-
-            # Accumulate gradients
-            gradient_accumulation_steps = cfg.schedule.gradient_accumulation_steps
-            is_accumulating = accumulate_steps % gradient_accumulation_steps != 0
-            accumulate_steps += 1
-
-            # Train one step
-            with fabric.no_backward_sync(model, enabled=is_accumulating):
-                outputs = model(**batch)
-                loss = outputs.loss
-                metrics = getattr(outputs, "metrics", {})
-
-                # Need to divide loss by accumulation steps
-                fabric.backward(loss / gradient_accumulation_steps)
-
-                # Update trackers
-                trackers["loss"].append(float(loss))
-                trackers["lr"].append(float(optimizer.param_groups[0]["lr"]))
-                for k, v in metrics.items():
-                    trackers[f"metrics/{k}"].append(float(v))
-
-            trackers["model_time"].append(time.time() - model_begin_time)
-
-            if is_accumulating:
-                last_batch_time = time.time()
-                continue
-
-            # Check all trackers has the same length
-            assert (
-                len(set(len(v) for k, v in trackers.items() if k != "grad_norm")) == 1
-            ), "Trackers has ambiguous length"
-
-            # Perform gradient clipping
-            grad_norm = fabric.clip_gradients(
-                model,
-                optimizer,
-                max_norm=cfg.schedule.clip_grad_norm,
-                norm_type=2.0,
-                error_if_nonfinite=True,
-            )
+log = utils.RankedLogger(__name__, rank_zero_only=True)
+
+
+@utils.task_wrapper
+def train(cfg: DictConfig) -> tuple[dict, dict]:
+    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
+    training.
+    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+    failure. Useful for multiruns, saving info about the crash, etc.
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+    Returns:
+        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
+    """  # noqa: E501
+
+    # set seed for random number generators in pytorch, numpy and python.random
+    if cfg.get("seed"):
+        L.seed_everything(cfg.seed, workers=True)
+
+    if cfg.get("deterministic"):
+        torch.use_deterministic_algorithms(True)
+
+    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+    log.info(f"Instantiating model <{cfg.model._target_}>")
+    model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+    log.info("Instantiating callbacks...")
+    callbacks: list[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
+
+    log.info("Instantiating loggers...")
+    logger: list[Logger] = utils.instantiate_loggers(cfg.get("logger"))
+
+    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+    trainer: Trainer = hydra.utils.instantiate(
+        cfg.trainer, callbacks=callbacks, logger=logger
+    )
+
+    object_dict = {
+        "cfg": cfg,
+        "datamodule": datamodule,
+        "model": model,
+        "callbacks": callbacks,
+        "logger": logger,
+        "trainer": trainer,
+    }
+
+    if logger:
+        log.info("Logging hyperparameters!")
+        utils.log_hyperparameters(object_dict)
 
-            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
-                log.warning(f"Gradient norm is {grad_norm}, skipping update")
-                optimizer.zero_grad()
-
-            # We can't average gradients across multiple steps
-            trackers["grad_norm"].append(float(grad_norm))
-
-            # Update
-            optimizer.step()
-            optimizer.zero_grad()
-            scheduler.step()
-
-            fabric.log_dict(
-                {
-                    f"train/{k}": sum(v[-gradient_accumulation_steps:])
-                    / len(v[-gradient_accumulation_steps:])
-                    for k, v in trackers.items()
-                },
-                step=global_step,
+    if cfg.get("compile"):
+        log.info("Compiling model!")
+        model = torch.compile(model)
+
+    if cfg.get("train"):
+        log.info("Starting training!")
+
+        ckpt_path = cfg.get("ckpt_path")
+
+        if ckpt_path is None:
+            ckpt_path = utils.get_latest_checkpoint(cfg.paths.ckpt_dir)
+
+        if ckpt_path is not None:
+            log.info(f"Resuming from checkpoint: {ckpt_path}")
+
+        if cfg.get("resume_weights_only"):
+            log.info("Resuming weights only!")
+            ckpt = torch.load(ckpt_path, map_location=model.device)
+            model.load_state_dict(
+                ckpt["state_dict"] if "state_dict" in ckpt else ckpt, strict=False
             )
+            ckpt_path = None
+
+        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
 
-            # accumulate_steps = 0
-            global_step += 1
-
-            if global_step % cfg.schedule.log_interval == 0:
-                step_time = (time.time() - start_time) / cfg.schedule.log_interval
-                eta = step_time * (cfg.schedule.max_steps - global_step)
-                additional_info = [
-                    f"{k}: {sum(v[-cfg.schedule.log_interval:]) / len(v[-cfg.schedule.log_interval:]):.4f}"
-                    for k, v in trackers.items()
-                    if k != "lr"  # lr use .2e format
-                ]
-
-                log.info(
-                    f"[{global_step}/{cfg.schedule.max_steps}] "
-                    + f"step_time: {step_time:.2f}s "
-                    + f"ETA: {timedelta(seconds=round(eta))}s "
-                    f"lr: {optimizer.param_groups[0]['lr']:.2e} "
-                    + " ".join(additional_info)
-                )
-
-                # Reset trackers
-                trackers = defaultdict(list)
-
-                start_time = time.time()
-
-            if global_step % cfg.schedule.save_interval == 0:
-                fabric.save(
-                    Path(cfg.paths.checkpoint_dir) / f"step_{global_step}.ckpt",
-                    {
-                        "model": model,
-                        "optimizer": optimizer,
-                        "scheduler": scheduler.state_dict(),
-                        "global_step": global_step,
-                    },
-                )
-
-            if (
-                getattr(cfg.schedule, "eval_interval", None) is not None
-                and global_step % cfg.schedule.eval_interval == 0
-                and valid_dataloader is not None
-            ):
-                valid(model, valid_dataloader, global_step, fabric, cfg)
-
-            if global_step >= cfg.schedule.max_steps:
-                break
-
-            last_batch_time = time.time()
+    train_metrics = trainer.callback_metrics
+
+    if cfg.get("test"):
+        log.info("Starting testing!")
+        ckpt_path = trainer.checkpoint_callback.best_model_path
+        if ckpt_path == "":
+            log.warning("Best ckpt not found! Using current weights for testing...")
+            ckpt_path = cfg.get("ckpt_path")
+
+        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
+        log.info(f"Best ckpt path: {ckpt_path}")
+
+    test_metrics = trainer.callback_metrics
+
+    # merge train and test metrics
+    metric_dict = {**train_metrics, **test_metrics}
+
+    return metric_dict, object_dict
 
 
 @hydra.main(
     version_base="1.3", config_path="./configs", config_name="llama_pretrain.yaml"
 )
-def main(cfg: DictConfig):
-    log.info(f"Config: \n{OmegaConf.to_yaml(cfg)}")
-
-    if is_flash_attn_available() is False:
-        log.warning("Flash attention is not available, using default attention")
-
-    fabric: Fabric = hydra.utils.instantiate(cfg.trainer)
-    fabric.launch()
-    log.info(f"Fabric: {fabric}")
-
-    model = hydra.utils.instantiate(cfg.model)
-    log.info(f"Model: {repr(model)}")
-
-    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    freeze_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
-    log.info(f"Trainable parameters: {trainable_params/1e6:.2f}M")
-    log.info(f"Freeze parameters: {freeze_params/1e6:.2f}M")
-
-    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
-    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
-    log.info(f"Optimizer: {optimizer}")
-    log.info(f"Scheduler: {scheduler}")
-
-    log.info(f"Setup fabric model & dataset")
-    model = fabric.setup_module(model)
-    optimizer = fabric.setup_optimizers(optimizer)
-
-    # Build state
-    global_step = 0
-
-    # Restore training from checkpoint
-    checkpoint_dir = Path(cfg.paths.checkpoint_dir)
-    checkpoint_dir.mkdir(parents=True, exist_ok=True)
-
-    # Alphabetically sort checkpoints
-    checkpoints = natsorted(checkpoint_dir.glob("*.ckpt"))
-    if len(checkpoints) > 0:
-        checkpoint_path = checkpoints[-1]
-
-        log.info(f"Restoring checkpoint from {checkpoint_path}")
-        remainder = fabric.load(
-            checkpoint_path,
-            {
-                "model": model,
-                "optimizer": optimizer,
-                "scheduler": scheduler,
-            },
-        )
-        global_step = remainder["global_step"]
-        log.info(f"Restored global step: {global_step}")
-
-    train_dataloader = hydra.utils.instantiate(cfg.train_dataloader)
-    log.info(f"Train Dataloader: {train_dataloader}")
-
-    valid_dataloader = None
-    if getattr(cfg, "valid_dataloader", None) is not None:
-        valid_dataloader = hydra.utils.instantiate(cfg.valid_dataloader)
-        log.info(f"Valid Dataloader: {valid_dataloader}")
-
-    train_dataloader = fabric.setup_dataloaders(train_dataloader)
-    if valid_dataloader is not None:
-        valid_dataloader = fabric.setup_dataloaders(valid_dataloader)
-
-    log.info(f"Begin training")
-
-    train(
-        model=model,
-        optimizer=optimizer,
-        scheduler=scheduler,
-        train_dataloader=train_dataloader,
-        valid_dataloader=valid_dataloader,
-        global_step=global_step,
-        fabric=fabric,
-        cfg=cfg,
-    )
+def main(cfg: DictConfig) -> Optional[float]:
+    # train the model
+    train(cfg)
 
 
 if __name__ == "__main__":
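
Hedged examples of the two resume paths added above; `+key=value` adds keys not present in base.yaml, and the checkpoint paths are illustrative. If no ckpt_path is given, training auto-resumes from the latest checkpoint found in paths.ckpt_dir.

# Full resume (restores optimizer and trainer state via trainer.fit):
python fish_speech/train.py +ckpt_path=results/llama_finetune/checkpoints/last.ckpt

# Load model weights only, then start with fresh trainer state:
python fish_speech/train.py +ckpt_path=results/llama_finetune/checkpoints/last.ckpt +resume_weights_only=true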

+ 21 - 0
fish_speech/utils/__init__.py

@@ -0,0 +1,21 @@
+from .braceexpand import braceexpand
+from .file import get_latest_checkpoint
+from .instantiators import instantiate_callbacks, instantiate_loggers
+from .logger import RankedLogger
+from .logging_utils import log_hyperparameters
+from .rich_utils import enforce_tags, print_config_tree
+from .utils import extras, get_metric_value, task_wrapper
+
+__all__ = [
+    "enforce_tags",
+    "extras",
+    "get_metric_value",
+    "RankedLogger",
+    "instantiate_callbacks",
+    "instantiate_loggers",
+    "log_hyperparameters",
+    "print_config_tree",
+    "task_wrapper",
+    "braceexpand",
+    "get_latest_checkpoint",
+]

+ 74 - 0
fish_speech/utils/file.py

@@ -0,0 +1,74 @@
+import os
+from pathlib import Path
+from typing import Optional, Union
+
+AUDIO_EXTENSIONS = {
+    ".mp3",
+    ".wav",
+    ".flac",
+    ".ogg",
+    ".m4a",
+    ".wma",
+    ".aac",
+    ".aiff",
+    ".aif",
+    ".aifc",
+}
+
+
+def list_files(
+    path: Union[Path, str],
+    extensions: Optional[set[str]] = None,
+    recursive: bool = False,
+    sort: bool = True,
+) -> list[Path]:
+    """List files in a directory.
+
+    Args:
+        path (Path): Path to the directory.
+        extensions (set, optional): Extensions to filter. Defaults to None.
+        recursive (bool, optional): Whether to search recursively. Defaults to False.
+        sort (bool, optional): Whether to sort the files. Defaults to True.
+
+    Returns:
+        list: List of files.
+    """
+
+    if isinstance(path, str):
+        path = Path(path)
+
+    if not path.exists():
+        raise FileNotFoundError(f"Directory {path} does not exist.")
+
+    files = (
+        [
+            Path(os.path.join(root, filename))
+            for root, _, filenames in os.walk(path, followlinks=True)
+            for filename in filenames
+            if Path(os.path.join(root, filename)).is_file()
+        ]
+        if recursive
+        else [f for f in path.glob("*") if f.is_file()]
+    )
+
+    if extensions is not None:
+        files = [f for f in files if f.suffix in extensions]
+
+    if sort:
+        files = sorted(files)
+
+    return files
+
+
+def get_latest_checkpoint(path: Path | str) -> Path | None:
+    # Find the latest checkpoint
+    ckpt_dir = Path(path)
+
+    if not ckpt_dir.exists():
+        return None
+
+    ckpts = sorted(ckpt_dir.glob("*.ckpt"), key=os.path.getmtime)
+    if len(ckpts) == 0:
+        return None
+
+    return ckpts[-1]
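
Usage sketch for the two helpers above (directory paths are illustrative):

from fish_speech.utils.file import AUDIO_EXTENSIONS, get_latest_checkpoint, list_files

wavs = list_files("data/WenetSpeech", extensions=AUDIO_EXTENSIONS, recursive=True)
latest = get_latest_checkpoint("results/llama_finetune/checkpoints")  # Path or None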

+ 50 - 0
fish_speech/utils/instantiators.py

@@ -0,0 +1,50 @@
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+from lightning import Callback
+from lightning.pytorch.loggers import Logger
+
+from .logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
+    """Instantiates callbacks from config."""
+
+    callbacks: List[Callback] = []
+
+    if not callbacks_cfg:
+        log.warning("No callback configs found! Skipping..")
+        return callbacks
+
+    if not isinstance(callbacks_cfg, DictConfig):
+        raise TypeError("Callbacks config must be a DictConfig!")
+
+    for _, cb_conf in callbacks_cfg.items():
+        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
+            log.info(f"Instantiating callback <{cb_conf._target_}>")
+            callbacks.append(hydra.utils.instantiate(cb_conf))
+
+    return callbacks
+
+
+def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
+    """Instantiates loggers from config."""
+
+    logger: List[Logger] = []
+
+    if not logger_cfg:
+        log.warning("No logger configs found! Skipping...")
+        return logger
+
+    if not isinstance(logger_cfg, DictConfig):
+        raise TypeError("Logger config must be a DictConfig!")
+
+    for _, lg_conf in logger_cfg.items():
+        if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
+            log.info(f"Instantiating logger <{lg_conf._target_}>")
+            logger.append(hydra.utils.instantiate(lg_conf))
+
+    return logger
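
A small sketch of what these helpers consume, using the grad-norm node from base.yaml (the inline OmegaConf literal is illustrative):

from omegaconf import OmegaConf
from fish_speech.utils.instantiators import instantiate_callbacks

cfg = OmegaConf.create(
    {
        "grad_norm_monitor": {
            "_target_": "fish_speech.callbacks.GradNormMonitor",
            "norm_type": 2,
            "logging_interval": "step",
        }
    }
)
callbacks = instantiate_callbacks(cfg)  # -> [GradNormMonitor(...)]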

+ 1 - 1
fish_speech/logger.py → fish_speech/utils/logger.py

@@ -10,7 +10,7 @@ class RankedLogger(logging.LoggerAdapter):
     def __init__(
         self,
         name: str = __name__,
-        rank_zero_only: bool = False,
+        rank_zero_only: bool = True,
         extra: Optional[Mapping[str, object]] = None,
     ) -> None:
         """Initializes a multi-GPU-friendly python command line logger that logs on all processes

+ 48 - 0
fish_speech/utils/logging_utils.py

@@ -0,0 +1,48 @@
+from lightning.pytorch.utilities import rank_zero_only
+
+from fish_speech.utils.logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+@rank_zero_only
+def log_hyperparameters(object_dict: dict) -> None:
+    """Controls which config parts are saved by lightning loggers.
+
+    Additionally saves:
+    - Number of model parameters
+    """
+
+    hparams = {}
+
+    cfg = object_dict["cfg"]
+    model = object_dict["model"]
+    trainer = object_dict["trainer"]
+
+    if not trainer.logger:
+        log.warning("Logger not found! Skipping hyperparameter logging...")
+        return
+
+    hparams["model"] = cfg["model"]
+
+    # save number of model parameters
+    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
+    hparams["model/params/trainable"] = sum(
+        p.numel() for p in model.parameters() if p.requires_grad
+    )
+    hparams["model/params/non_trainable"] = sum(
+        p.numel() for p in model.parameters() if not p.requires_grad
+    )
+
+    hparams["data"] = cfg["data"]
+    hparams["trainer"] = cfg["trainer"]
+
+    hparams["callbacks"] = cfg.get("callbacks")
+    hparams["extras"] = cfg.get("extras")
+
+    hparams["task_name"] = cfg.get("task_name")
+    hparams["tags"] = cfg.get("tags")
+    hparams["ckpt_path"] = cfg.get("ckpt_path")
+    hparams["seed"] = cfg.get("seed")
+
+    # send hparams to all loggers
+    for logger in trainer.loggers:
+        logger.log_hyperparams(hparams)

+ 96 - 0
fish_speech/utils/rich_utils.py

@@ -0,0 +1,96 @@
+from pathlib import Path
+from typing import Sequence
+
+import rich
+import rich.syntax
+import rich.tree
+from hydra.core.hydra_config import HydraConfig
+from lightning.pytorch.utilities import rank_zero_only
+from omegaconf import DictConfig, OmegaConf, open_dict
+from rich.prompt import Prompt
+
+from fish_speech.utils.logger import RankedLogger
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+@rank_zero_only
+def print_config_tree(
+    cfg: DictConfig,
+    print_order: Sequence[str] = (
+        "data",
+        "model",
+        "callbacks",
+        "logger",
+        "trainer",
+        "paths",
+        "extras",
+    ),
+    resolve: bool = False,
+    save_to_file: bool = False,
+) -> None:
+    """Prints content of DictConfig using Rich library and its tree structure.
+
+    Args:
+        cfg (DictConfig): Configuration composed by Hydra.
+        print_order (Sequence[str], optional): Determines in what order config components are printed.
+        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+        save_to_file (bool, optional): Whether to export config to the hydra output folder.
+    """  # noqa: E501
+
+    style = "dim"
+    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
+
+    queue = []
+
+    # add fields from `print_order` to queue
+    for field in print_order:
+        if field in cfg:
+            queue.append(field)
+        else:
+            log.warning(
+                f"Field '{field}' not found in config. "
+                f"Skipping '{field}' config printing..."
+            )
+
+    # add all the other fields to queue (not specified in `print_order`)
+    for field in cfg:
+        if field not in queue:
+            queue.append(field)
+
+    # generate config tree from queue
+    for field in queue:
+        branch = tree.add(field, style=style, guide_style=style)
+
+        config_group = cfg[field]
+        if isinstance(config_group, DictConfig):
+            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+        else:
+            branch_content = str(config_group)
+
+        branch.add(rich.syntax.Syntax(branch_content, "yaml"))
+
+    # print config tree
+    rich.print(tree)
+
+    # save config tree to file
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+            rich.print(tree, file=file)
+
+
+@rank_zero_only
+def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
+    """Prompts user to input tags from command line if no tags are provided in config."""  # noqa: E501
+
+    if not cfg.get("tags"):
+        if "id" in HydraConfig().cfg.hydra.job:
+            raise ValueError("Specify tags before launching a multirun!")
+
+        log.warning("No tags provided in config. Prompting user to input tags...")
+        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+        tags = [t.strip() for t in tags.split(",") if t != ""]
+
+        with open_dict(cfg):
+            cfg.tags = tags
+
+        log.info(f"Tags: {cfg.tags}")
+
+    if save_to_file:
+        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
+            rich.print(cfg.tags, file=file)

+ 114 - 0
fish_speech/utils/utils.py

@@ -0,0 +1,114 @@
+import warnings
+from importlib.util import find_spec
+from typing import Callable
+
+from omegaconf import DictConfig
+
+from .logger import RankedLogger
+from .rich_utils import enforce_tags, print_config_tree
+
+log = RankedLogger(__name__, rank_zero_only=True)
+
+
+def extras(cfg: DictConfig) -> None:
+    """Applies optional utilities before the task is started.
+
+    Utilities:
+    - Ignoring python warnings
+    - Setting tags from command line
+    - Rich config printing
+    """
+
+    # return if no `extras` config
+    if not cfg.get("extras"):
+        log.warning("Extras config not found! <cfg.extras=null>")
+        return
+
+    # disable python warnings
+    if cfg.extras.get("ignore_warnings"):
+        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
+        warnings.filterwarnings("ignore")
+
+    # prompt user to input tags from command line if none are provided in the config
+    if cfg.extras.get("enforce_tags"):
+        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
+        enforce_tags(cfg, save_to_file=True)
+
+    # pretty print config tree using Rich library
+    if cfg.extras.get("print_config"):
+        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
+        print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+def task_wrapper(task_func: Callable) -> Callable:
+    """Optional decorator that controls the failure behavior when executing the task function.
+
+    This wrapper can be used to:
+    - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+    - save the exception to a `.log` file
+    - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+    - etc. (adjust depending on your needs)
+
+    Example:
+    ```
+    @utils.task_wrapper
+    def train(cfg: DictConfig) -> Tuple[dict, dict]:
+
+        ...
+
+        return metric_dict, object_dict
+    ```
+    """  # noqa: E501
+
+    def wrap(cfg: DictConfig):
+        # execute the task
+        try:
+            metric_dict, object_dict = task_func(cfg=cfg)
+
+        # things to do if exception occurs
+        except Exception as ex:
+            # save exception to `.log` file
+            log.exception("")
+
+            # some hyperparameter combinations might be invalid or
+            # cause out-of-memory errors so when using hparam search
+            # plugins like Optuna, you might want to disable
+            # raising the below exception to avoid multirun failure
+            raise ex
+
+        # things to always do after either success or exception
+        finally:
+            # display output dir path in terminal
+            log.info(f"Output dir: {cfg.paths.run_dir}")
+
+            # always close wandb run (even if exception occurs so multirun won't fail)
+            if find_spec("wandb"):  # check if wandb is installed
+                import wandb
+
+                if wandb.run:
+                    log.info("Closing wandb!")
+                    wandb.finish()
+
+        return metric_dict, object_dict
+
+    return wrap
+
+
+def get_metric_value(metric_dict: dict, metric_name: str) -> float:
+    """Safely retrieves value of the metric logged in LightningModule."""
+
+    if not metric_name:
+        log.info("Metric name is None! Skipping metric value retrieval...")
+        return None
+
+    if metric_name not in metric_dict:
+        raise Exception(
+            f"Metric value not found! <metric_name={metric_name}>\n"
+            "Make sure metric name logged in LightningModule is correct!\n"
+            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+        )
+
+    metric_value = metric_dict[metric_name].item()
+    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+    return metric_value

+ 29 - 0
fish_speech/utils/viz.py

@@ -0,0 +1,29 @@
+import matplotlib
+from matplotlib import pyplot as plt
+from torch import Tensor
+
+matplotlib.use("Agg")
+
+
+def plot_mel(data, titles=None):
+    fig, axes = plt.subplots(len(data), 1, squeeze=False)
+
+    if titles is None:
+        titles = [None] * len(data)
+
+    plt.tight_layout()
+
+    for i in range(len(data)):
+        mel = data[i]
+
+        if isinstance(mel, Tensor):
+            mel = mel.detach().cpu().numpy()
+
+        axes[i][0].imshow(mel, origin="lower")
+        axes[i][0].set_aspect(2.5, adjustable="box")
+        axes[i][0].set_ylim(0, mel.shape[0])
+        axes[i][0].set_title(titles[i], fontsize="medium")
+        axes[i][0].tick_params(labelsize="x-small", left=False, labelleft=False)
+        axes[i][0].set_anchor("W")
+
+    return fig
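
Usage sketch with random data; the TensorBoard call assumes a Lightning logger is attached:

import torch
from fish_speech.utils.viz import plot_mel

fig = plot_mel([torch.randn(80, 200), torch.randn(80, 200)], titles=["pred", "gt"])
fig.savefig("mel.png")
# or, inside a LightningModule:
# self.logger.experiment.add_figure("val/mel", fig, self.global_step)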

+ 6 - 0
pyrightconfig.json

@@ -0,0 +1,6 @@
+{
+    "exclude": [
+        "data",
+        "filelists"
+    ]
+}

+ 32 - 0
tools/build_vq_text.py

@@ -0,0 +1,32 @@
+from pathlib import Path
+
+from datasets import Dataset
+
+
+def parse_data(wav_dir, item):
+    text_file = (wav_dir / item["item_name"]).with_suffix(".txt")
+    text = text_file.read_text().strip()
+
+    semantic = item["semantic_audio"]
+    semantic = [f"<semantic_{x}>" for x in semantic.split(" ")]
+    semantic = " ".join(semantic)
+
+    text = f"[INST] {text} [/INST] {semantic} </s>"
+
+    return {
+        "text": text,
+    }
+
+
+if __name__ == "__main__":
+    # dataset = WenetVQDataset()
+    # dataset = list(dataset)
+    # print("Initialized dataset.")
+    dataset = Dataset.from_csv("data/cn-hubert-wenet-25hz-semantic.tsv", delimiter="\t")
+    dataset = dataset.map(
+        lambda item: parse_data(Path("data/WenetSpeech"), item), num_proc=64
+    )
+    dataset = dataset.remove_columns(["item_name", "semantic_audio"])
+    dataset = dataset.train_test_split(test_size=0.01)
+    print(dataset["test"][0])
+    dataset.push_to_hub("fishaudio/wenet-vq", private=True)