# vqgan_finetune.yaml (3.1 KB)
  1. defaults:
  2. - base
  3. - _self_
  4. project: vq-gan-finetune
  5. ckpt_path: checkpoints/vq-gan-group-fsq-2x1024.pth
  6. resume_weights_only: true
  7. # Lightning Trainer
  8. trainer:
  9. accelerator: gpu
  10. devices: auto
  11. precision: bf16-mixed
  12. max_steps: 100_000
  13. val_check_interval: 5000
  14. strategy:
  15. find_unused_parameters: true
  16. sample_rate: 44100
  17. hop_length: 512
  18. num_mels: 128
  19. n_fft: 2048
  20. win_length: 2048
  21. # Dataset Configuration
  22. train_dataset:
  23. _target_: fish_speech.datasets.vqgan.VQGANDataset
  24. filelist: data/vq_train_filelist.txt
  25. sample_rate: ${sample_rate}
  26. hop_length: ${hop_length}
  27. slice_frames: 512
  28. val_dataset:
  29. _target_: fish_speech.datasets.vqgan.VQGANDataset
  30. filelist: data/vq_val_filelist.txt
  31. sample_rate: ${sample_rate}
  32. hop_length: ${hop_length}
  33. data:
  34. _target_: fish_speech.datasets.vqgan.VQGANDataModule
  35. train_dataset: ${train_dataset}
  36. val_dataset: ${val_dataset}
  37. num_workers: 4
  38. batch_size: 16
  39. val_batch_size: 16
  40. # Model Configuration
  41. model:
  42. _target_: fish_speech.models.vqgan.VQGAN
  43. sampling_rate: ${sample_rate}
  44. weight_adv: 0.2
  45. weight_vq: 1.0
  46. weight_mel: 1.0
  47. # Important: Set the freeze_encoder to true to only train the decoder
  48. freeze_encoder: true
  49. encoder:
  50. _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
  51. input_channels: ${num_mels}
  52. residual_channels: 768
  53. residual_layers: 20
  54. dilation_cycle: 4
  55. quantizer:
  56. _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
  57. input_dim: 768
  58. n_codebooks: 1
  59. n_groups: 2
  60. levels: [8, 5, 5, 5]
  61. decoder:
  62. _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
  63. output_channels: ${num_mels}
  64. residual_channels: 768
  65. residual_layers: 20
  66. dilation_cycle: 4
  67. condition_channels: 768
  68. discriminator:
  69. _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
  70. vocoder:
  71. _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
  72. ckpt_path: null # You may download the pretrained vocoder and set the path here
  73. encode_mel_transform:
  74. _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
  75. sample_rate: ${sample_rate}
  76. n_fft: ${n_fft}
  77. hop_length: ${hop_length}
  78. win_length: ${win_length}
  79. n_mels: ${num_mels}
  80. f_min: 0.0
  81. f_max: 8000.0
  82. gt_mel_transform:
  83. _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
  84. sample_rate: ${sample_rate}
  85. n_fft: ${n_fft}
  86. hop_length: ${hop_length}
  87. win_length: ${win_length}
  88. n_mels: ${num_mels}
  89. optimizer:
  90. _target_: torch.optim.AdamW
  91. _partial_: true
  92. lr: 4e-5
  93. betas: [0.8, 0.99]
  94. eps: 1e-5
  95. weight_decay: 0.01
  96. lr_scheduler:
  97. _target_: torch.optim.lr_scheduler.LambdaLR
  98. _partial_: true
  99. lr_lambda:
  100. _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
  101. _partial_: true
  102. num_warmup_steps: 0
  103. num_training_steps: ${trainer.max_steps}
  104. final_lr_ratio: 0
  105. callbacks:
  106. model_summary:
  107. _target_: lightning.pytorch.callbacks.ModelSummary
  108. max_depth: 1
  109. model_checkpoint:
  110. every_n_train_steps: ${trainer.val_check_interval}
  111. grad_norm_monitor:
  112. sub_module:
  113. - encoder
  114. - decoder
  115. - quantizer
  116. - discriminator