# hubert_vq_diffusion.yaml — Hydra config for the HuBERT VQ-diffusion experiment.
  1. defaults:
  2. - base
  3. - _self_
  4. project: hubert_vq_diffusion
  5. # Lightning Trainer
  6. trainer:
  7. accelerator: gpu
  8. devices: 4
  9. strategy: ddp_find_unused_parameters_true
  10. gradient_clip_val: 1.0
  11. gradient_clip_algorithm: 'norm'
  12. precision: 16-mixed
  13. max_steps: 1_000_000
  14. val_check_interval: 5000
  15. sample_rate: 44100
  16. hop_length: 512
  17. num_mels: 128
  18. n_fft: 2048
  19. win_length: 2048
  20. # Dataset Configuration
  21. train_dataset:
  22. _target_: fish_speech.datasets.vqgan.VQGANDataset
  23. filelist: data/filelist.split.train
  24. sample_rate: ${sample_rate}
  25. hop_length: ${hop_length}
  26. slice_frames: 512
  27. val_dataset:
  28. _target_: fish_speech.datasets.vqgan.VQGANDataset
  29. filelist: data/filelist.split.valid
  30. sample_rate: ${sample_rate}
  31. hop_length: ${hop_length}
  32. data:
  33. _target_: fish_speech.datasets.vqgan.VQGANDataModule
  34. train_dataset: ${train_dataset}
  35. val_dataset: ${val_dataset}
  36. num_workers: 8
  37. batch_size: 32
  38. val_batch_size: 4
  39. # Model Configuration
  40. model:
  41. _target_: fish_speech.models.vq_diffusion.lit_module.VQDiffusion
  42. sample_rate: ${sample_rate}
  43. hop_length: ${hop_length}
  44. text_encoder:
  45. _target_: fish_speech.models.vqgan.modules.encoders.TextEncoder
  46. in_channels: 128
  47. out_channels: 128
  48. hidden_channels: 192
  49. hidden_channels_ffn: 768
  50. n_heads: 2
  51. n_layers: 6
  52. kernel_size: 1
  53. dropout: 0.1
  54. use_vae: false
  55. gin_channels: 512
  56. speaker_cond_layer: 0
  57. vq_encoder:
  58. _target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
  59. in_channels: 128
  60. vq_channels: 128
  61. codebook_size: 16384
  62. downsample: 1
  63. speaker_encoder:
  64. _target_: fish_speech.models.vqgan.modules.encoders.SpeakerEncoder
  65. in_channels: 128
  66. hidden_channels: 192
  67. out_channels: 128
  68. num_heads: 2
  69. num_layers: 4
  70. p_dropout: 0.1
  71. denoiser:
  72. _target_: fish_speech.models.vq_diffusion.convnext_1d.ConvNext1DModel
  73. in_channels: 256
  74. out_channels: 128
  75. intermediate_dim: 512
  76. # condition_dim: 128
  77. mlp_dim: 2048
  78. num_layers: 20
  79. dilation_cycle_length: 2
  80. time_embedding_type: "positional"
  81. vocoder:
  82. _target_: fish_speech.models.vq_diffusion.adamos.ADaMoSHiFiGANV1
  83. mel_transform:
  84. _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
  85. sample_rate: ${sample_rate}
  86. n_fft: ${n_fft}
  87. hop_length: ${hop_length}
  88. win_length: ${win_length}
  89. n_mels: ${num_mels}
  90. f_min: 40
  91. f_max: 16000
  92. feature_mel_transform:
  93. _target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
  94. sample_rate: 32000
  95. n_fft: 2048
  96. hop_length: 640
  97. win_length: 2048
  98. n_mels: 128
  99. optimizer:
  100. _target_: torch.optim.AdamW
  101. _partial_: true
  102. lr: 1e-4
  103. betas: [0.9, 0.999]
  104. eps: 1e-5
  105. lr_scheduler:
  106. _target_: torch.optim.lr_scheduler.LambdaLR
  107. _partial_: true
  108. lr_lambda:
  109. _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
  110. _partial_: true
  111. num_warmup_steps: 0
  112. num_training_steps: ${trainer.max_steps}
  113. final_lr_ratio: 0.05
  114. callbacks:
  115. grad_norm_monitor:
  116. sub_module:
  117. - vq_encoder
  118. - text_encoder
  119. - speaker_encoder
  120. - denoiser