# text2semantic_sft.yaml — supervised fine-tuning (SFT) config for the
# dual-AR text2semantic model (Hydra/OmegaConf + PyTorch Lightning).
  1. defaults:
  2. - base
  3. - model@model.model: dual_ar_8_codebook_small
  4. - _self_
  5. project: text2semantic_sft_medium_dual_ar
  6. max_length: 4096
  7. ckpt_path: results/text2semantic_pretrain_medium_dual_ar/checkpoints/step_000060000.ckpt
  8. resume_weights_only: true
  9. # Lightning Trainer
  10. trainer:
  11. accumulate_grad_batches: 1
  12. gradient_clip_val: 1.0
  13. gradient_clip_algorithm: 'norm'
  14. max_steps: 10_000
  15. precision: bf16-true
  16. limit_val_batches: 10
  17. val_check_interval: 500
  18. # Dataset Configuration
  19. tokenizer:
  20. _target_: transformers.AutoTokenizer.from_pretrained
  21. pretrained_model_name_or_path: fishaudio/speech-lm-v1
  22. # Dataset Configuration
  23. train_dataset:
  24. _target_: fish_speech.datasets.text.AutoAugTextDataset
  25. use_data_server: false
  26. proto_files:
  27. - data/protos/sft/train_Genshin.protos
  28. - data/protos/sft/sft.protos
  29. tokenizer: ${tokenizer}
  30. max_length: ${max_length}
  31. num_codebooks: ${model.model.config.num_codebooks}
  32. use_speaker: false
  33. phones_prob: 0.5
  34. interactive_prob: 0.5
  35. val_dataset:
  36. _target_: fish_speech.datasets.text.AutoAugTextDataset
  37. use_data_server: false
  38. proto_files:
  39. - data/protos/sft/val_Genshin.protos
  40. tokenizer: ${tokenizer}
  41. max_length: ${max_length}
  42. num_codebooks: ${model.model.config.num_codebooks}
  43. use_speaker: false
  44. phones_prob: 0.5
  45. interactive_prob: 0.5
  46. data:
  47. _target_: fish_speech.datasets.text.TextDataModule
  48. train_dataset: ${train_dataset}
  49. val_dataset: ${val_dataset}
  50. num_workers: 4
  51. batch_size: 8
  52. tokenizer: ${tokenizer}
  53. max_length: ${max_length}
  54. # Model Configuration
  55. model:
  56. _target_: fish_speech.models.text2semantic.TextToSemantic
  57. model: {}
  58. optimizer:
  59. _target_: torch.optim.AdamW
  60. _partial_: true
  61. lr: 4e-5
  62. weight_decay: 0
  63. betas: [0.9, 0.95]
  64. eps: 1e-5
  65. lr_scheduler:
  66. _target_: torch.optim.lr_scheduler.LambdaLR
  67. _partial_: true
  68. lr_lambda:
  69. _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
  70. _partial_: true
  71. num_warmup_steps: 100
  72. num_training_steps: ${trainer.max_steps}
  73. final_lr_ratio: 0
  74. callbacks:
  75. model_checkpoint:
  76. every_n_train_steps: 1000
  77. save_top_k: 10