# Fine-tuning configuration for the VQ-GAN model (project `vq-gan-finetune`).
# Written in Hydra/OmegaConf style: `_target_` names the class or function to
# instantiate, `_partial_: true` defers the call, and `${...}` interpolates
# another key of this config.

# Compose the `base` config first, then this file (`_self_`), so the values
# below take precedence over anything `base` defines.
defaults:
  - base
  - _self_

project: vq-gan-finetune
ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth
resume_weights_only: true

# Lightning Trainer
trainer:
  accelerator: gpu
  devices: auto
  precision: bf16-mixed
  max_steps: 100_000
  val_check_interval: 5000
  # No `_target_` here; presumably this only overrides one option of a
  # strategy declared in `base` -- confirm against that file.
  strategy:
    find_unused_parameters: true

# Audio front-end parameters. They are defined once here and shared with the
# datasets, the model and both mel transforms through `${...}` interpolation.
sample_rate: 44100
hop_length: 512
num_mels: 128
n_fft: 2048
win_length: 2048

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: data/vq_train_filelist.txt
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}
  slice_frames: 512

# Same dataset class as training, but `slice_frames` is not set here.
val_dataset:
  _target_: fish_speech.datasets.vqgan.VQGANDataset
  filelist: data/vq_val_filelist.txt
  sample_rate: ${sample_rate}
  hop_length: ${hop_length}

data:
  _target_: fish_speech.datasets.vqgan.VQGANDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 16
  val_batch_size: 16

# Model Configuration
model:
  _target_: fish_speech.models.vqgan.VQGAN

  sampling_rate: ${sample_rate}
  weight_adv: 0.2
  weight_vq: 1.0
  weight_mel: 1.0

  # Important: Set the freeze_encoder to true to only train the decoder
  freeze_encoder: true

  encoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    input_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4

  # `input_dim` equals the encoder's `residual_channels` (768); presumably
  # the two must be kept in sync -- verify before changing either one.
  quantizer:
    _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
    input_dim: 768
    n_codebooks: 1
    n_groups: 2
    levels: [8, 5, 5, 5]

  decoder:
    _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
    output_channels: ${num_mels}
    residual_channels: 768
    residual_layers: 20
    dilation_cycle: 4
    condition_channels: 768

  discriminator:
    _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator

  vocoder:
    _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
    ckpt_path: null  # You may download the pretrained vocoder and set the path here

  # Mel transform with an explicit 0-8000 Hz band (`f_min` / `f_max`).
  encode_mel_transform:
    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}
    f_min: 0.0
    f_max: 8000.0

  # Same transform class and parameters as above, but `f_min` / `f_max` are
  # not set, so the class defaults apply (values not visible in this file).
  gt_mel_transform:
    _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
    sample_rate: ${sample_rate}
    n_fft: ${n_fft}
    hop_length: ${hop_length}
    win_length: ${win_length}
    n_mels: ${num_mels}

  # NOTE(review): `4e-5` and `1e-5` have no decimal point. A plain YAML 1.1
  # loader such as PyYAML reads such scalars as strings; this presumably
  # relies on the OmegaConf loader resolving them as floats -- confirm
  # before loading this file any other way.
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 4e-5
    betas: [0.8, 0.99]
    eps: 1e-5
    weight_decay: 0.01

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 0
      # The schedule length follows the trainer's step budget.
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0

callbacks:
  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 1

  # `model_checkpoint` and `grad_norm_monitor` carry no `_target_` here;
  # presumably `base` declares them and these keys only override options.
  model_checkpoint:
    # Checkpoint interval is tied to the validation interval.
    every_n_train_steps: ${trainer.val_check_interval}

  grad_norm_monitor:
    sub_module:
      - encoder
      - decoder
      - quantizer
      - discriminator