defaults:
  - base
  - model@model.model: dual_ar_2_codebook_small
  - _self_

project: text2semantic_pretrain_dual_ar_debug
max_length: 2048

# Lightning Trainer
trainer:
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: 'norm'
  max_steps: 1_000_000
  precision: bf16-true
  limit_val_batches: 10

# Dataset Configuration
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: fishaudio/fish-speech-1

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  proto_files:
    - data/protos/train
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  num_codebooks: ${model.model.config.num_codebooks}
  use_speaker: false
  interactive_prob: 0.5
  skip_text_prob: 0.1

val_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  proto_files:
    - data/protos/test
  tokenizer: ${tokenizer}
  max_length: ${max_length}
  num_codebooks: ${model.model.config.num_codebooks}
  use_speaker: false
  interactive_prob: 0.5
  skip_text_prob: 0.1

data:
  _target_: fish_speech.datasets.text.TextDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 8
  tokenizer: ${tokenizer}
  max_length: ${max_length}

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic
  model: {}

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 3e-4
    weight_decay: 0.01
    betas: [0.9, 0.95]
    eps: 1e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 2000
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0.1