# text2semantic_sft_medium.yaml
  1. defaults:
  2. - base
  3. - _self_
  4. project: text2semantic_sft_medium_delay
  5. max_length: 4096
  6. use_delay_pattern: false
  7. ckpt_path: results/text2semantic_pretrain_medium_4_in_8_codebooks/checkpoints/step_000100000.ckpt
  8. resume_weights_only: true
  9. # Lightning Trainer
  10. trainer:
  11. accumulate_grad_batches: 1
  12. gradient_clip_val: 1.0
  13. gradient_clip_algorithm: 'norm'
  14. max_steps: 10_000
  15. precision: bf16-true
  16. limit_val_batches: 10
  17. val_check_interval: 500
  18. # Dataset Configuration
  19. tokenizer:
  20. _target_: transformers.AutoTokenizer.from_pretrained
  21. pretrained_model_name_or_path: fishaudio/speech-lm-v1
  22. # Dataset Configuration
  23. train_dataset:
  24. _target_: fish_speech.datasets.text.AutoAugTextDataset
  25. use_data_server: false
  26. proto_files:
  27. - data/protos/sft/train_Genshin.protos
  28. - data/protos/sft/sft.protos
  29. tokenizer: ${tokenizer}
  30. max_length: ${max_length}
  31. num_codebooks: ${model.model.config.num_codebooks}
  32. use_speaker: true
  33. phones_prob: 0.5
  34. interactive_prob: 0.5
  35. use_delay_pattern: ${use_delay_pattern}
  36. val_dataset:
  37. _target_: fish_speech.datasets.text.AutoAugTextDataset
  38. use_data_server: false
  39. proto_files:
  40. - data/protos/sft/val_Genshin.protos
  41. tokenizer: ${tokenizer}
  42. max_length: ${max_length}
  43. num_codebooks: ${model.model.config.num_codebooks}
  44. use_speaker: true
  45. phones_prob: 0.5
  46. interactive_prob: 0.5
  47. use_delay_pattern: ${use_delay_pattern}
  48. data:
  49. _target_: fish_speech.datasets.text.TextDataModule
  50. train_dataset: ${train_dataset}
  51. val_dataset: ${val_dataset}
  52. num_workers: 4
  53. batch_size: 16
  54. tokenizer: ${tokenizer}
  55. max_length: ${max_length}
  56. # Model Configuration
  57. model:
  58. _target_: fish_speech.models.text2semantic.TextToSemantic
  59. model:
  60. # ~ 130M parameters, for debug purpose
  61. _target_: fish_speech.models.text2semantic.llama.Transformer
  62. config:
  63. _target_: fish_speech.models.text2semantic.llama.ModelArgs
  64. max_seq_len: 4096
  65. vocab_size: 36408
  66. n_layer: 24
  67. n_head: 16
  68. dim: 1024
  69. rope_base: 10000
  70. norm_eps: 1e-5
  71. num_in_codebooks: 4 # input codebook size
  72. num_codebooks: 8 # output codebook size
  73. codebook_size: 264 # codebook size 256 + 2 special tokens
  74. dropout: 0
  75. neft_alpha: 0
  76. optimizer:
  77. _target_: bitsandbytes.optim.AdamW8bit
  78. _partial_: true
  79. lr: 4e-5
  80. weight_decay: 0
  81. betas: [0.9, 0.95]
  82. eps: 1e-5
  83. lr_scheduler:
  84. _target_: torch.optim.lr_scheduler.LambdaLR
  85. _partial_: true
  86. lr_lambda:
  87. _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
  88. _partial_: true
  89. num_warmup_steps: 100
  90. num_training_steps: ${trainer.max_steps}
  91. final_lr_ratio: 0
  92. callbacks:
  93. model_checkpoint:
  94. every_n_train_steps: 1000
  95. save_top_k: 10