Sfoglia il codice sorgente

Update llama pretrain config (max 100 val batches)

Lengyue 2 anni fa
parent
commit
66219b20cf
3 ha cambiato i file con 62 aggiunte e 72 eliminazioni
  1. 56 67
      fish_speech/configs/llama_pretrain.yaml
  2. 5 5
      fish_speech/datasets/text.py
  3. 1 0
      requirements.txt

+ 56 - 67
fish_speech/configs/llama_pretrain.yaml

@@ -1,84 +1,73 @@
-paths:
-  run_dir: results/pretrain
-  checkpoint_dir: ${paths.run_dir}/checkpoints
+defaults:
+  - base
+  - _self_
 
-hydra:
-  run:
-    dir: ${paths.run_dir}
+project: llama_pretrain
 
-trainer:
-  _target_: lightning.fabric.Fabric
-  accelerator: gpu
-  strategy:
-    _target_: lightning.fabric.strategies.DDPStrategy
-    static_graph: true
-  num_nodes: 8
-  devices: 8
-  precision: bf16-mixed
-  loggers:
-    _target_: pytorch_lightning.loggers.TensorBoardLogger
-    save_dir: ${paths.run_dir}
-    name: tensorboard
-    version: null
+# Say we want a 3 trillion seen token schedule
+# 3e12 / 1024 / 512 / 8 = 715255
+# But we use a 100k steps schedule here to save time
+# This is a ~420 billion seen token schedule:
+# 1024 * 512 * 8 * 100000 = 419_430_400_000
 
-model:
-  _target_: transformers.AutoModelForCausalLM.from_pretrained
-  pretrained_model_name_or_path: fishaudio/speech-lm-300m
-  revision: init
+# Lightning Trainer
+trainer:
+  accumulate_grad_batches: 64
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  num_nodes: 1
+  limit_val_batches: 100 # 100 batches for validation
 
+# Tokenizer Configuration
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
   pretrained_model_name_or_path: fishaudio/speech-lm-300m
   revision: init
 
-# Say we want a 3 trillion seen token schedule
-# 3e12 / 1024 / 512 / 8 = 715255
-# But we use a 100k steps schedule here to save time
-# This is a 300 billion seen token schedule
-schedule:
-  max_length: 1024
-  batch_size: 64  # 128 * 4 = 512
-  micro_batch_size: 8
-  max_steps: 100000
-  save_interval: 5000
-  log_interval: 10
-  gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
-  clip_grad_norm: 1.0
-
+# Dataset Configuration
 dataset:
-  _target_: fish_speech.datasets.cultura_x.InterleaveDataset
+  _target_: fish_speech.datasets.text.InterleaveDataset
   datasets:
-    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
-      lang: 'en'
-    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
-      lang: 'zh'
-    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
-      lang: 'ja'
+    - _target_: fish_speech.datasets.text.TextDataset
+      prefix: 'en/'
+    - _target_: fish_speech.datasets.text.TextDataset
+      prefix: 'zh/'
+    - _target_: fish_speech.datasets.text.TextDataset
+      prefix: 'ja/'
   probabilities: [0.4, 0.3, 0.3]
   seed: 42
 
-train_dataloader:
-  _target_: torch.utils.data.DataLoader
-  dataset: ${dataset}
-  batch_size: ${schedule.micro_batch_size}
-  num_workers: 8
-  collate_fn:
-    _target_: fish_speech.datasets.cultura_x.CulutreXCollator
-    tokenizer: ${tokenizer}
-    max_length: ${schedule.max_length}
+data:
+  _target_: fish_speech.datasets.text.TextDataModule
+  train_dataset: ${dataset}
+  val_dataset: ${dataset}
+  num_workers: 4
+  batch_size: 8
+  tokenizer: ${tokenizer}
 
-optimizer:
-  _target_: torch.optim.AdamW
-  lr: 3e-4
-  weight_decay: 0.1
-  betas: [0.9, 0.95]
-  eps: 1e-5
+# Model Configuration
+model:
+  _target_: fish_speech.models.text2semantic.TextToSemantic
+
+  model:
+    _target_: transformers.AutoModelForCausalLM.from_pretrained
+    pretrained_model_name_or_path: fishaudio/speech-lm-300m
+    revision: init
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 3e-4
+    weight_decay: 0.1
+    betas: [0.9, 0.95]
+    eps: 1e-5
 
-scheduler:
-  _target_: torch.optim.lr_scheduler.LambdaLR
-  lr_lambda:
-    _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
     _partial_: true
-    num_warmup_steps: 2000
-    num_training_steps: ${schedule.max_steps}
-    final_lr_ratio: 0.1
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 2000
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.1

+ 5 - 5
fish_speech/datasets/text.py

@@ -9,6 +9,7 @@ import pyarrow.parquet as pq
 from datasets.download.streaming_download_manager import xopen
 from huggingface_hub import HfApi
 from lightning import LightningDataModule
+from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from torch.distributed import get_rank, get_world_size, is_initialized
 from torch.utils.data import DataLoader, IterableDataset, get_worker_info
 from transformers import AutoTokenizer
@@ -39,7 +40,9 @@ class TextDataset(IterableDataset):
 
         if prefix is not None:
             files = HfApi().list_repo_files(repo, repo_type="dataset")
-            files = [f for f in files if f.startswith(prefix)]
+            files = [
+                f for f in files if f.startswith(prefix) and f.endswith(".parquet")
+            ]
             log.info(f"Found {len(files)} files in {repo} with prefix {prefix}")
         else:
             if isinstance(files, str):
@@ -162,7 +165,7 @@ class TextDataModule(LightningDataModule):
     def __init__(
         self,
         train_dataset: Union[TextDataset, InterleaveDataset],
-        val_dataset: Optional[Union[TextDataset, InterleaveDataset]] = None,
+        val_dataset: Union[TextDataset, InterleaveDataset],
         batch_size: int = 32,
         tokenizer: AutoTokenizer = None,
         max_length: int = 1024,
@@ -186,9 +189,6 @@ class TextDataModule(LightningDataModule):
         )
 
     def val_dataloader(self):
-        if self.val_dataset is None:
-            return None
-
         return DataLoader(
             self.val_dataset,
             batch_size=self.batch_size,

+ 1 - 0
requirements.txt

@@ -9,3 +9,4 @@ natsort>=8.4.0
 einops>=0.7.0
 librosa>=0.10.1
 vector-quantize-pytorch>=1.9.18
+rich>=13.5.3