# text2semantic_pretrain.yaml
  1. defaults:
  2. - base
  3. - model@model.model: dual_ar_2_codebook_small
  4. - _self_
  5. project: text2semantic_pretrain_dual_ar_debug
  6. max_length: 2048
  7. # Lightning Trainer
  8. trainer:
  9. accumulate_grad_batches: 1
  10. gradient_clip_val: 1.0
  11. gradient_clip_algorithm: 'norm'
  12. max_steps: 1_000_000
  13. precision: bf16-true
  14. limit_val_batches: 10
  15. # Dataset Configuration
  16. tokenizer:
  17. _target_: transformers.AutoTokenizer.from_pretrained
  18. pretrained_model_name_or_path: fishaudio/fish-speech-1
  19. # Dataset Configuration
  20. train_dataset:
  21. _target_: fish_speech.datasets.text.AutoAugTextDataset
  22. proto_files:
  23. - data/protos/train
  24. tokenizer: ${tokenizer}
  25. max_length: ${max_length}
  26. num_codebooks: ${model.model.config.num_codebooks}
  27. use_speaker: false
  28. interactive_prob: 0.5
  29. val_dataset:
  30. _target_: fish_speech.datasets.text.AutoAugTextDataset
  31. proto_files:
  32. - data/protos/test
  33. tokenizer: ${tokenizer}
  34. max_length: ${max_length}
  35. num_codebooks: ${model.model.config.num_codebooks}
  36. use_speaker: false
  37. interactive_prob: 0.5
  38. data:
  39. _target_: fish_speech.datasets.text.TextDataModule
  40. train_dataset: ${train_dataset}
  41. val_dataset: ${val_dataset}
  42. num_workers: 4
  43. batch_size: 8
  44. tokenizer: ${tokenizer}
  45. max_length: ${max_length}
  46. # Model Configuration
  47. model:
  48. _target_: fish_speech.models.text2semantic.TextToSemantic
  49. model: {}
  50. optimizer:
  51. _target_: torch.optim.AdamW
  52. _partial_: true
  53. lr: 3e-4
  54. weight_decay: 0.01
  55. betas: [0.9, 0.95]
  56. eps: 1e-5
  57. lr_scheduler:
  58. _target_: torch.optim.lr_scheduler.LambdaLR
  59. _partial_: true
  60. lr_lambda:
  61. _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
  62. _partial_: true
  63. num_warmup_steps: 2000
  64. num_training_steps: ${trainer.max_steps}
  65. final_lr_ratio: 0.1