| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586 |
- defaults:
- - base
- - _self_
- project: text2semantic_finetune_dual_ar
- max_length: 4096
- pretrained_ckpt_path: checkpoints/openaudio-s1-mini
- # Lightning Trainer
- trainer:
- accumulate_grad_batches: 1
- gradient_clip_val: 1.0
- gradient_clip_algorithm: "norm"
- max_steps: 10000
- precision: bf16-true
- limit_val_batches: 10
- val_check_interval: 100
- # strategy:
- # find_unused_parameters: true
- # static_graph: true
- # Tokenizer Configuration
- tokenizer:
- _target_: fish_speech.tokenizer.FishTokenizer
- model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken
- # Dataset Configuration
- train_dataset:
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
- proto_files:
- - data/protos
- tokenizer: ${tokenizer}
- causal: true
- max_length: ${max_length}
- use_speaker: false
- interactive_prob: 0.7
- val_dataset:
- _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
- proto_files:
- - data/protos
- tokenizer: ${tokenizer}
- causal: true
- max_length: ${max_length}
- use_speaker: false
- interactive_prob: 0.7
- data:
- _target_: fish_speech.datasets.semantic.SemanticDataModule
- train_dataset: ${train_dataset}
- val_dataset: ${val_dataset}
- num_workers: 4
- batch_size: 4
- tokenizer: ${tokenizer}
- max_length: ${max_length}
- # Model Configuration
- model:
- _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
- model:
- _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
- path: ${pretrained_ckpt_path}
- load_weights: true
- max_length: ${max_length}
- lora_config: null
- optimizer:
- _target_: torch.optim.AdamW
- _partial_: true
- lr: 1e-4
- weight_decay: 0
- betas: [0.9, 0.95]
- eps: 1e-5
- lr_scheduler:
- _target_: torch.optim.lr_scheduler.LambdaLR
- _partial_: true
- lr_lambda:
- _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
- _partial_: true
- num_warmup_steps: 10
- # Callbacks
- callbacks:
- model_checkpoint:
- every_n_train_steps: ${trainer.val_check_interval}
|