# vqgan_pretrain.yaml — VQ-GAN pretraining configuration (Hydra)
  1. defaults:
  2. - base
  3. - _self_
  4. project: vq-gan-pretrain
  5. # Lightning Trainer
  6. trainer:
  7. accelerator: gpu
  8. devices: auto
  9. precision: bf16-mixed
  10. max_steps: 1_000_000
  11. val_check_interval: 5000
  12. strategy:
  13. find_unused_parameters: true
  14. sample_rate: 44100
  15. hop_length: 512
  16. num_mels: 128
  17. n_fft: 2048
  18. win_length: 2048
  19. # Dataset Configuration
  20. train_dataset:
  21. _target_: torch.utils.data.ConcatDataset
  22. datasets:
  23. - _target_: fish_speech.datasets.vqgan.VQGANDataset
  24. filelist: data/gigaspeech/vq_train_filelist.txt
  25. sample_rate: ${sample_rate}
  26. hop_length: ${hop_length}
  27. slice_frames: 512
  28. - _target_: fish_speech.datasets.vqgan.VQGANDataset
  29. filelist: data/sft/vq_train_filelist.txt
  30. sample_rate: ${sample_rate}
  31. hop_length: ${hop_length}
  32. slice_frames: 512
  33. val_dataset:
  34. _target_: fish_speech.datasets.vqgan.VQGANDataset
  35. filelist: data/sft/vq_val_filelist.txt
  36. sample_rate: ${sample_rate}
  37. hop_length: ${hop_length}
  38. data:
  39. _target_: fish_speech.datasets.vqgan.VQGANDataModule
  40. train_dataset: ${train_dataset}
  41. val_dataset: ${val_dataset}
  42. num_workers: 4
  43. batch_size: 32
  44. val_batch_size: 32
  45. # Model Configuration
  46. model:
  47. _target_: fish_speech.models.vqgan.VQGAN
  48. sampling_rate: ${sample_rate}
  49. weight_adv: 0.2
  50. weight_vq: 1.0
  51. weight_mel: 1.0
  52. freeze_encoder: false
  53. encoder:
  54. _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
  55. input_channels: ${num_mels}
  56. residual_channels: 768
  57. residual_layers: 20
  58. dilation_cycle: 4
  59. quantizer:
  60. _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
  61. input_dim: 768
  62. n_codebooks: 1
  63. n_groups: 2
  64. levels: [8, 5, 5, 5]
  65. decoder:
  66. _target_: fish_speech.models.vqgan.modules.wavenet.WaveNet
  67. output_channels: ${num_mels}
  68. residual_channels: 768
  69. residual_layers: 20
  70. dilation_cycle: 4
  71. condition_channels: 768
  72. discriminator:
  73. _target_: fish_speech.models.vqgan.modules.discriminator.Discriminator
  74. vocoder:
  75. _target_: fish_speech.models.vqgan.modules.firefly.FireflyBase
  76. ckpt_path: null # You may download the pretrained vocoder and set the path here
  77. encode_mel_transform:
  78. _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
  79. sample_rate: ${sample_rate}
  80. n_fft: ${n_fft}
  81. hop_length: ${hop_length}
  82. win_length: ${win_length}
  83. n_mels: ${num_mels}
  84. f_min: 0.0
  85. f_max: 8000.0
  86. gt_mel_transform:
  87. _target_: fish_speech.utils.spectrogram.LogMelSpectrogram
  88. sample_rate: ${sample_rate}
  89. n_fft: ${n_fft}
  90. hop_length: ${hop_length}
  91. win_length: ${win_length}
  92. n_mels: ${num_mels}
  93. optimizer:
  94. _target_: torch.optim.AdamW
  95. _partial_: true
  96. lr: 1e-4
  97. betas: [0.8, 0.99]
  98. eps: 1e-5
  99. weight_decay: 0.01
  100. lr_scheduler:
  101. _target_: torch.optim.lr_scheduler.LambdaLR
  102. _partial_: true
  103. lr_lambda:
  104. _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
  105. _partial_: true
  106. num_warmup_steps: 100
  107. num_training_steps: ${trainer.max_steps}
  108. final_lr_ratio: 0
  109. callbacks:
  110. model_summary:
  111. _target_: lightning.pytorch.callbacks.ModelSummary
  112. max_depth: 1
  113. model_checkpoint:
  114. every_n_train_steps: ${trainer.val_check_interval}
  115. grad_norm_monitor:
  116. sub_module:
  117. - encoder
  118. - decoder
  119. - quantizer
  120. - discriminator