@@ -1,84 +1,73 @@
-paths:
-  run_dir: results/pretrain
-  checkpoint_dir: ${paths.run_dir}/checkpoints
+defaults:
+  - base
+  - _self_
 
-hydra:
-  run:
-    dir: ${paths.run_dir}
+project: llama_pretrain
 
-trainer:
-  _target_: lightning.fabric.Fabric
-  accelerator: gpu
-  strategy:
-    _target_: lightning.fabric.strategies.DDPStrategy
-    static_graph: true
-  num_nodes: 8
-  devices: 8
-  precision: bf16-mixed
-  loggers:
-    _target_: pytorch_lightning.loggers.TensorBoardLogger
-    save_dir: ${paths.run_dir}
-    name: tensorboard
-    version: null
+# Say we want a 3 trillion seen token schedule
+# 3e12 / 1024 / 512 / 8 = 715255
+# But we use a 100k steps schedule here to save time
+# This is a 400 billion seen token schedule:
+# 1024 * 512 * 8 * 100000 = 419_430_400_000
 
-model:
-  _target_: transformers.AutoModelForCausalLM.from_pretrained
-  pretrained_model_name_or_path: fishaudio/speech-lm-300m
-  revision: init
+# Lightning Trainer
+trainer:
+  accumulate_grad_batches: 64
+  gradient_clip_val: 1.0
+  gradient_clip_algorithm: 'norm'
+  num_nodes: 1
+  limit_val_batches: 100  # 100 batches for validation
 
+# Tokenizer Configuration
 tokenizer:
   _target_: transformers.AutoTokenizer.from_pretrained
   pretrained_model_name_or_path: fishaudio/speech-lm-300m
   revision: init
 
-# Say we want a 3 trillion seen token schedule
-# 3e12 / 1024 / 512 / 8 = 715255
-# But we use a 100k steps schedule here to save time
-# This is a 300 billion seen token schedule
-schedule:
-  max_length: 1024
-  batch_size: 64 # 128 * 4 = 512
-  micro_batch_size: 8
-  max_steps: 100000
-  save_interval: 5000
-  log_interval: 10
-  gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
-  clip_grad_norm: 1.0
-
+# Dataset Configuration
 dataset:
-  _target_: fish_speech.datasets.cultura_x.InterleaveDataset
+  _target_: fish_speech.datasets.text.InterleaveDataset
   datasets:
-    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
-      lang: 'en'
-    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
-      lang: 'zh'
-    - _target_: fish_speech.datasets.cultura_x.CulturaXDataset
-      lang: 'ja'
+    - _target_: fish_speech.datasets.text.TextDataset
+      prefix: 'en/'
+    - _target_: fish_speech.datasets.text.TextDataset
+      prefix: 'zh/'
+    - _target_: fish_speech.datasets.text.TextDataset
+      prefix: 'ja/'
   probabilities: [0.4, 0.3, 0.3]
   seed: 42
 
-train_dataloader:
-  _target_: torch.utils.data.DataLoader
-  dataset: ${dataset}
-  batch_size: ${schedule.micro_batch_size}
-  num_workers: 8
-  collate_fn:
-    _target_: fish_speech.datasets.cultura_x.CulutreXCollator
-    tokenizer: ${tokenizer}
-    max_length: ${schedule.max_length}
+data:
+  _target_: fish_speech.datasets.text.TextDataModule
+  train_dataset: ${dataset}
+  val_dataset: ${dataset}
+  num_workers: 4
+  batch_size: 8
+  tokenizer: ${tokenizer}
 
-optimizer:
-  _target_: torch.optim.AdamW
-  lr: 3e-4
-  weight_decay: 0.1
-  betas: [0.9, 0.95]
-  eps: 1e-5
+# Model Configuration
+model:
+  _target_: fish_speech.models.text2semantic.TextToSemantic
+
+  model:
+    _target_: transformers.AutoModelForCausalLM.from_pretrained
+    pretrained_model_name_or_path: fishaudio/speech-lm-300m
+    revision: init
+
+  optimizer:
+    _target_: torch.optim.AdamW
+    _partial_: true
+    lr: 3e-4
+    weight_decay: 0.1
+    betas: [0.9, 0.95]
+    eps: 1e-5
 
-scheduler:
-  _target_: torch.optim.lr_scheduler.LambdaLR
-  lr_lambda:
-    _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+  lr_scheduler:
+    _target_: torch.optim.lr_scheduler.LambdaLR
     _partial_: true
-    num_warmup_steps: 2000
-    num_training_steps: ${schedule.max_steps}
-    final_lr_ratio: 0.1
+    lr_lambda:
+      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
+      _partial_: true
+      num_warmup_steps: 2000
+      num_training_steps: ${trainer.max_steps}
+      final_lr_ratio: 0.1
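
A quick sanity check of the token-budget comments in the new config, in Python. This is a sketch, not project code: it assumes the 512 in the comments is the effective batch size (data.batch_size 8 x trainer.accumulate_grad_batches 64) and the trailing 8 is the device count inherited from the `base` defaults, neither of which is stated in this diff.

    # Sanity-check the seen-token arithmetic from the config comments.
    seq_len = 1024       # tokens per sample
    micro_batch = 8      # data.batch_size
    grad_accum = 64      # trainer.accumulate_grad_batches
    devices = 8          # assumption: device count from the `base` defaults

    tokens_per_step = seq_len * micro_batch * grad_accum * devices  # 4_194_304
    print(int(3e12 // tokens_per_step))  # 715255 steps for a 3T-token schedule
    print(tokens_per_step * 100_000)     # 419_430_400_000 tokens in 100k steps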
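Note on `_partial_: true`: Hydra instantiates such nodes as `functools.partial` objects, so the remaining arguments (the model's parameters for the optimizer, the optimizer instance for `LambdaLR`, the step for `lr_lambda`) can be supplied at runtime. The diff does not show the body of `get_cosine_schedule_with_warmup_lr_lambda`; below is a minimal sketch of the shape such a lambda usually takes (linear warmup, then cosine decay to `final_lr_ratio` of the base LR), written as an illustration rather than the fish_speech implementation.

    import math

    def get_cosine_schedule_with_warmup_lr_lambda(
        current_step: int,
        *,
        num_warmup_steps: int,
        num_training_steps: int,
        final_lr_ratio: float,
    ) -> float:
        # Linear warmup from 0 up to the base learning rate.
        if current_step < num_warmup_steps:
            return current_step / max(1, num_warmup_steps)
        # Cosine decay from the base LR down to final_lr_ratio * base LR.
        progress = (current_step - num_warmup_steps) / max(
            1, num_training_steps - num_warmup_steps
        )
        return final_lr_ratio + (1.0 - final_lr_ratio) * 0.5 * (
            1.0 + math.cos(math.pi * min(1.0, progress))
        )

`LambdaLR` multiplies the optimizer's base LR (3e-4 here) by this factor at every step, so the schedule ends at 3e-5 after `num_training_steps`.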