defaults:
  - base
  - _self_

project: text2semantic_sft_medium_delay
max_length: 4096
use_delay_pattern: false
ckpt_path: results/text2semantic_pretrain_medium_4_in_8_codebooks/checkpoints/step_000100000.ckpt
resume_weights_only: true
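# ckpt_path points at the 100k-step pretraining checkpoint; with resume_weights_only
# only the model weights are restored (presumably optimizer/trainer state starts fresh for SFT).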

# Lightning Trainer
trainer:
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: 'norm'
  max_steps: 10_000
  precision: bf16-true
  limit_val_batches: 10
  val_check_interval: 500
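# Validation runs every 500 steps on at most 10 batches; with no gradient accumulation,
# the effective batch size equals data.batch_size (16) per device.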

# Tokenizer Configuration
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: fishaudio/speech-lm-v1
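# Hydra instantiates the tokenizer once via _target_; the datasets and data module
# below reuse it through the ${tokenizer} interpolation.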

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  use_data_server: false
  proto_files:
    - data/protos/sft/train_Genshin.protos
    - data/protos/sft/sft.protos
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  num_codebooks: ${model.model.config.num_codebooks}
  use_speaker: true
  phones_prob: 0.5
  interactive_prob: 0.5
  use_delay_pattern: ${use_delay_pattern}
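# phones_prob / interactive_prob are sampling probabilities used by AutoAugTextDataset,
# presumably controlling how often phoneme annotations and the interactive prompt format appear.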

val_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  use_data_server: false
  proto_files:
    - data/protos/sft/val_Genshin.protos
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  num_codebooks: ${model.model.config.num_codebooks}
  use_speaker: true
  phones_prob: 0.5
  interactive_prob: 0.5
  use_delay_pattern: ${use_delay_pattern}
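# The validation dataset mirrors the training-set options but reads only the Genshin validation protos.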

data:
  _target_: fish_speech.datasets.text.TextDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 16
  tokenizer: ${tokenizer}
  max_length: ${max_length}
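# TextDataModule wraps both datasets into Lightning dataloaders; batch_size is
# presumably per device (per GPU process).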

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic
  model:
    # Medium model: 24 layers, 16 heads, dim 1024 (roughly 400M parameters)
    _target_: fish_speech.models.text2semantic.llama.Transformer
    config:
      _target_: fish_speech.models.text2semantic.llama.ModelArgs
      max_seq_len: 4096
      vocab_size: 36408
      n_layer: 24
      n_head: 16
      dim: 1024
      rope_base: 10000
      norm_eps: 1e-5
      num_in_codebooks: 4 # number of input codebooks
      num_codebooks: 8 # number of output codebooks
      codebook_size: 264 # 256 codebook entries + 8 extra special/padding entries
      dropout: 0
      neft_alpha: 0
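      # The 4-in / 8-out codebook layout matches the pretrained '4_in_8_codebooks'
      # checkpoint referenced by ckpt_path above.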

  optimizer:
    _target_: bitsandbytes.optim.AdamW8bit
    _partial_: true
    lr: 4e-5
    weight_decay: 0
    betas: [0.9, 0.95]
    eps: 1e-5
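  # AdamW8bit requires the bitsandbytes package; _partial_: true makes Hydra return a
  # functools.partial, so the training module can bind model.parameters() at setup time.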

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 100
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0
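  # Linear warmup for 100 steps, then cosine decay toward final_lr_ratio * lr (0 here)
  # over trainer.max_steps optimizer steps.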

callbacks:
  model_checkpoint:
    every_n_train_steps: 1000
    save_top_k: 10
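
# To launch training with this config (assumption: the repository's standard Hydra entrypoint):
#   python fish_speech/train.py --config-name <name_of_this_config_file>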