# vq_diffusion.yaml — Hydra/Lightning training configuration for the VQ-Diffusion model.
  1. defaults:
  2. - base
  3. - _self_
  4. project: vq_diffusion
  5. # Lightning Trainer
  6. trainer:
  7. accelerator: gpu
  8. devices: 4
  9. strategy: ddp_find_unused_parameters_true
  10. gradient_clip_val: 1.0
  11. gradient_clip_algorithm: 'norm'
  12. precision: 16-mixed
  13. max_steps: 300_000
  14. val_check_interval: 5000
  15. sample_rate: 24000
  16. hop_length: 256
  17. num_mels: 100
  18. n_fft: 1024
  19. win_length: 1024
  20. # Dataset Configuration
  21. train_dataset:
  22. _target_: fish_speech.datasets.vqgan.VQGANDataset
  23. filelist: data/filelist.split.train
  24. sample_rate: ${sample_rate}
  25. hop_length: ${hop_length}
  26. slice_frames: 512
  27. val_dataset:
  28. _target_: fish_speech.datasets.vqgan.VQGANDataset
  29. filelist: data/filelist.split.valid
  30. sample_rate: ${sample_rate}
  31. hop_length: ${hop_length}
  32. data:
  33. _target_: fish_speech.datasets.vqgan.VQGANDataModule
  34. train_dataset: ${train_dataset}
  35. val_dataset: ${val_dataset}
  36. num_workers: 8
  37. batch_size: 32
  38. val_batch_size: 4
  39. # Model Configuration
  40. model:
  41. _target_: fish_speech.models.vq_diffusion.lit_module.VQDiffusion
  42. sample_rate: ${sample_rate}
  43. hop_length: ${hop_length}
  44. speaker_use_feats: true
  45. downsample:
  46. _target_: fish_speech.models.vq_diffusion.lit_module.ConvDownSample
  47. dims: [128, 512, 128]
  48. kernel_sizes: [3, 3]
  49. strides: [2, 2]
  50. text_encoder:
  51. _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
  52. in_channels: 128
  53. out_channels: 128
  54. hidden_channels: 192
  55. hidden_channels_ffn: 768
  56. n_heads: 2
  57. n_layers: 6
  58. kernel_size: 1
  59. dropout: 0.1
  60. use_vae: false
  61. gin_channels: 512
  62. speaker_cond_layer: 0
  63. vq_encoder:
  64. _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
  65. in_channels: 128
  66. vq_channels: 128
  67. codebook_size: 4096
  68. downsample: 1
  69. speaker_encoder:
  70. _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
  71. in_channels: 128
  72. hidden_channels: 192
  73. out_channels: 128
  74. num_heads: 2
  75. num_layers: 4
  76. p_dropout: 0.1
  77. denoiser:
  78. _target_: fish_speech.models.vq_diffusion.wavenet.WaveNet
  79. d_encoder: 128
  80. mel_channels: 100
  81. residual_channels: 512
  82. residual_layers: 20
  83. vocoder:
  84. _target_: fish_speech.models.vq_diffusion.bigvgan.BigVGAN
  85. mel_transform:
  86. _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
  87. sample_rate: ${sample_rate}
  88. n_fft: ${n_fft}
  89. hop_length: ${hop_length}
  90. win_length: ${win_length}
  91. n_mels: ${num_mels}
  92. f_min: 0
  93. f_max: 12000
  94. feature_mel_transform:
  95. _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
  96. sample_rate: 32000
  97. n_fft: 2048
  98. hop_length: 320
  99. win_length: 2048
  100. n_mels: 128
  101. optimizer:
  102. _target_: torch.optim.AdamW
  103. _partial_: true
  104. lr: 1e-4
  105. betas: [0.9, 0.999]
  106. eps: 1e-5
  107. lr_scheduler:
  108. _target_: torch.optim.lr_scheduler.LambdaLR
  109. _partial_: true
  110. lr_lambda:
  111. _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
  112. _partial_: true
  113. num_warmup_steps: 0
  114. num_training_steps: ${trainer.max_steps}
  115. final_lr_ratio: 0.05
  116. callbacks:
  117. grad_norm_monitor:
  118. sub_module:
  119. - vq_encoder
  120. - text_encoder
  121. - speaker_encoder
  122. - denoiser