# hubert_vq_diffusion.yaml — HuBERT VQ-diffusion training config (fish-speech / Hydra)

---
  1. defaults:
  2. - base
  3. - _self_
  4. project: hubert_vq_diffusion
  5. # Lightning Trainer
  6. trainer:
  7. accelerator: gpu
  8. devices: 4
  9. strategy: ddp_find_unused_parameters_true
  10. gradient_clip_val: 1.0
  11. gradient_clip_algorithm: 'norm'
  12. precision: 16-mixed
  13. max_steps: 1_000_000
  14. val_check_interval: 1000
  15. sample_rate: 44100
  16. hop_length: 512
  17. num_mels: 128
  18. n_fft: 2048
  19. win_length: 2048
  20. # Dataset Configuration
  21. train_dataset:
  22. _target_: fish_speech.datasets.vqgan.VQGANDataset
  23. filelist: data/vq_train_filelist.txt
  24. sample_rate: ${sample_rate}
  25. hop_length: ${hop_length}
  26. slice_frames: 512
  27. val_dataset:
  28. _target_: fish_speech.datasets.vqgan.VQGANDataset
  29. filelist: data/vq_val_filelist.txt
  30. sample_rate: ${sample_rate}
  31. hop_length: ${hop_length}
  32. data:
  33. _target_: fish_speech.datasets.vqgan.VQGANDataModule
  34. train_dataset: ${train_dataset}
  35. val_dataset: ${val_dataset}
  36. num_workers: 8
  37. batch_size: 32
  38. val_batch_size: 4
  39. # Model Configuration
  40. model:
  41. _target_: fish_speech.models.vq_diffusion.lit_module.VQDiffusion
  42. sample_rate: ${sample_rate}
  43. hop_length: ${hop_length}
  44. text_encoder:
  45. _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
  46. in_channels: 1024
  47. out_channels: 128
  48. hidden_channels: 192
  49. hidden_channels_ffn: 768
  50. n_heads: 2
  51. n_layers: 6
  52. kernel_size: 1
  53. dropout: 0.1
  54. use_vae: false
  55. gin_channels: 512
  56. speaker_cond_layer: 2
  57. vq_encoder:
  58. _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
  59. in_channels: 1024
  60. vq_channels: 1024
  61. codebook_size: 2048
  62. downsample: 2
  63. kmeans_ckpt: results/hubert-vq-pretrain/kmeans.pt
  64. speaker_encoder:
  65. _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
  66. in_channels: 128
  67. hidden_channels: 192
  68. out_channels: 512
  69. num_heads: 2
  70. num_layers: 4
  71. p_dropout: 0.1
  72. denoiser:
  73. _target_: fish_speech.models.vq_diffusion.convnext_1d.ConvNext1DModel
  74. in_channels: 256
  75. out_channels: 128
  76. intermediate_dim: 512
  77. # condition_dim: 128
  78. mlp_dim: 2048
  79. num_layers: 20
  80. dilation_cycle_length: 2
  81. time_embedding_type: "positional"
  82. vocoder:
  83. _target_: fish_speech.models.vq_diffusion.adamos.ADaMoSHiFiGANV1
  84. mel_transform:
  85. _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
  86. sample_rate: ${sample_rate}
  87. n_fft: ${n_fft}
  88. hop_length: ${hop_length}
  89. win_length: ${win_length}
  90. n_mels: ${num_mels}
  91. f_min: 40
  92. f_max: 16000
  # Optimizer / LR schedule, passed to the Lightning module as partials
  # (instantiated with model parameters at configure_optimizers time).
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    # Canonical float form: bare `1e-4` is resolved as a *string* by
    # PyYAML's YAML 1.1 float resolver (no dot before the exponent).
    lr: 1.0e-4
    betas: [0.9, 0.999]
    eps: 1.0e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 0
      # Tie the cosine horizon to the trainer's step budget.
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0.05
  108. callbacks:
  109. grad_norm_monitor:
  110. sub_module:
  111. - vq_encoder
  112. - text_encoder
  113. - speaker_encoder
  114. - denoiser