|
|
@@ -0,0 +1,116 @@
|
|
|
+paths:
|
|
|
+ run_dir: results/finetune
|
|
|
+ checkpoint_dir: ${paths.run_dir}/checkpoints
|
|
|
+
|
|
|
+hydra:
|
|
|
+ run:
|
|
|
+ dir: ${paths.run_dir}
|
|
|
+
|
|
|
+trainer:
|
|
|
+ _target_: lightning.fabric.Fabric
|
|
|
+ accelerator: gpu
|
|
|
+ strategy:
|
|
|
+ _target_: lightning.fabric.strategies.DDPStrategy
|
|
|
+ static_graph: true
|
|
|
+ num_nodes: 1
|
|
|
+ devices: 8
|
|
|
+ precision: bf16-mixed
|
|
|
+ loggers:
|
|
|
+ _target_: pytorch_lightning.loggers.TensorBoardLogger
|
|
|
+ save_dir: ${paths.run_dir}
|
|
|
+ name: tensorboard
|
|
|
+ version: null
|
|
|
+
|
|
|
+model:
|
|
|
+ _target_: transformers.AutoModelForCausalLM.from_pretrained
|
|
|
+ pretrained_model_name_or_path: fishaudio/speech-lm-300m
|
|
|
+ revision: text-pretrain-10k
|
|
|
+
|
|
|
+tokenizer:
|
|
|
+ _target_: transformers.AutoTokenizer.from_pretrained
|
|
|
+ pretrained_model_name_or_path: fishaudio/speech-lm-300m
|
|
|
+ revision: text-pretrain-10k
|
|
|
+
|
|
|
+# NOTE(review): labeled a 200-billion-seen-token schedule, but the visible values
|
|
|
+schedule:
|
|
|
+ max_length: 1024
|
|
|
+  batch_size: 16 # per-device effective batch; global = 16 * 8 devices = 128 (grad accum = 16 // 8 = 2)
|
|
|
+ micro_batch_size: 8
|
|
|
+ max_steps: 100000
|
|
|
+ save_interval: 5000
|
|
|
+ log_interval: 10
|
|
|
+ gradient_accumulation_steps: "${eval: ${schedule.batch_size} // ${schedule.micro_batch_size}}"
|
|
|
+ clip_grad_norm: 1.0
|
|
|
+
|
|
|
+train_dataset:
|
|
|
+ _target_: speech_lm.datasets.cultura_x.InterleaveDataset
|
|
|
+ datasets:
|
|
|
+ - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
|
|
|
+ lang: 'en'
|
|
|
+ - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
|
|
|
+ lang: 'zh'
|
|
|
+ - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
|
|
|
+ lang: 'ja'
|
|
|
+ - _target_: speech_lm.datasets.cultura_x.CulturaXDataset
|
|
|
+ repo: fishaudio/wenet-vq
|
|
|
+ files:
|
|
|
+ - data/train-00000-of-00018-b5a82c6054c6acca.parquet
|
|
|
+ - data/train-00001-of-00018-82467b3e0669c2be.parquet
|
|
|
+ - data/train-00002-of-00018-d50ed8c218a1f183.parquet
|
|
|
+ - data/train-00003-of-00018-15d666053eade100.parquet
|
|
|
+ - data/train-00004-of-00018-01868cb8408e012b.parquet
|
|
|
+ - data/train-00005-of-00018-e766a0b54b1fd08b.parquet
|
|
|
+ - data/train-00006-of-00018-c79fad54ea8a0b8d.parquet
|
|
|
+ - data/train-00007-of-00018-e4155011a7081a1d.parquet
|
|
|
+ - data/train-00008-of-00018-8ba319f5af359d15.parquet
|
|
|
+ - data/train-00009-of-00018-9c9e984a6565b2c3.parquet
|
|
|
+ - data/train-00010-of-00018-7af80a80e5aa1e54.parquet
|
|
|
+ - data/train-00011-of-00018-2ab91221787a84a3.parquet
|
|
|
+ - data/train-00012-of-00018-4d477812eea5d298.parquet
|
|
|
+ - data/train-00013-of-00018-faf87b68b1ab4a15.parquet
|
|
|
+ - data/train-00014-of-00018-7f6bbd9bcb4cbb55.parquet
|
|
|
+ - data/train-00015-of-00018-d630fe4a488b9f51.parquet
|
|
|
+ - data/train-00016-of-00018-969a4d5dc04d2764.parquet
|
|
|
+ - data/train-00017-of-00018-bbfd09175809d1fe.parquet
|
|
|
+ probabilities: [0.2, 0.2, 0.2, 0.4]
|
|
|
+ seed: 42
|
|
|
+
|
|
|
+train_dataloader:
|
|
|
+ _target_: torch.utils.data.DataLoader
|
|
|
+ dataset: ${train_dataset}
|
|
|
+ batch_size: ${schedule.micro_batch_size}
|
|
|
+ num_workers: 8
|
|
|
+ collate_fn:
|
|
|
+ _target_: speech_lm.datasets.cultura_x.CulutreXCollator
|
|
|
+ tokenizer: ${tokenizer}
|
|
|
+ max_length: ${schedule.max_length}
|
|
|
+
|
|
|
+valid_dataloader:
|
|
|
+ _target_: torch.utils.data.DataLoader
|
|
|
+ dataset:
|
|
|
+ _target_: speech_lm.datasets.cultura_x.CulturaXDataset
|
|
|
+ repo: fishaudio/wenet-vq
|
|
|
+ files:
|
|
|
+ - data/test-00000-of-00001-685250c116f5d321.parquet
|
|
|
+ batch_size: ${schedule.micro_batch_size}
|
|
|
+ num_workers: 1
|
|
|
+ collate_fn:
|
|
|
+ _target_: speech_lm.datasets.cultura_x.CulutreXCollator
|
|
|
+ tokenizer: ${tokenizer}
|
|
|
+ max_length: ${schedule.max_length}
|
|
|
+
|
|
|
+optimizer:
|
|
|
+ _target_: torch.optim.AdamW
|
|
|
+ lr: 1e-4
|
|
|
+ weight_decay: 0.1
|
|
|
+ betas: [0.9, 0.95]
|
|
|
+ eps: 1e-5
|
|
|
+
|
|
|
+scheduler:
|
|
|
+ _target_: torch.optim.lr_scheduler.LambdaLR
|
|
|
+ lr_lambda:
|
|
|
+ _target_: speech_lm.scheduler.get_cosine_schedule_with_warmup_lr_lambda
|
|
|
+ _partial_: true
|
|
|
+ num_warmup_steps: 2000
|
|
|
+ num_training_steps: ${schedule.max_steps}
|
|
|
+ final_lr_ratio: 0.1
|