# text2semantic_pretrain.yaml

defaults:
  - base
  - _self_

project: text2semantic_400m_pretrain
max_length: 1024

# Lightning Trainer
trainer:
  accumulate_grad_batches: 2
  gradient_clip_val: 1.0
  gradient_clip_algorithm: 'norm'
  max_steps: 1_000_000
  precision: bf16-true
  limit_val_batches: 10
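
# Note: with accumulate_grad_batches: 2, Lightning accumulates gradients
# over two batches before each optimizer step, and max_steps counts
# optimizer steps, not forward passes. precision: bf16-true keeps the
# weights themselves in bfloat16 (true half precision, not mixed).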

# Tokenizer Configuration
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: fishaudio/speech-lm-v1
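
# Note: every block with a _target_ key is built by hydra.utils.instantiate,
# which imports the dotted path and calls it with the remaining keys as
# keyword arguments; this one resolves to
# transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path="fishaudio/speech-lm-v1").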

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  tokenizer: ${tokenizer}
  max_length: ${max_length}

val_dataset:
  _target_: fish_speech.datasets.text.AutoAugTextDataset
  tokenizer: ${tokenizer}
  max_length: ${max_length}

data:
  _target_: fish_speech.datasets.text.TextDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 16
  tokenizer: ${tokenizer}
  max_length: ${max_length}
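
# Note: the ${...} references are OmegaConf interpolations resolved when
# the config is composed, so the datasets and the data module share a
# single tokenizer definition. Rough throughput per optimizer step and
# device: batch_size 16 x accumulate_grad_batches 2 x max_length 1024
# = 32,768 tokens (an upper bound; padding reduces the real count).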

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic
  model:
    # ~400M parameters
    _target_: fish_speech.models.text2semantic.llama.Transformer
    config:
      _target_: fish_speech.models.text2semantic.llama.ModelArgs
      max_seq_len: 4096
      vocab_size: 36408
      n_layer: 24
      n_head: 16
      dim: 1024
      rope_base: 10000
      norm_eps: 1e-5
      num_codebooks: 4  # number of semantic codebooks
      codebook_size: 168  # entries per codebook (VQ codes + special tokens)
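
  # Rough size check (assuming a standard LLaMA-style block, not stated in
  # the file): 12 * n_layer * dim^2 = 12 * 24 * 1024^2 ≈ 302M block weights,
  # plus ~37M for the 36408 x 1024 embedding table, ≈ 340M before the
  # codebook heads — consistent with the 400m in the project name.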

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 3e-4
    weight_decay: 0.1
    betas: [0.9, 0.95]
    eps: 1e-5
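
  # Note: _partial_: true makes Hydra return a functools.partial instead of
  # calling AdamW immediately; the LightningModule later supplies the model
  # parameters (and, for the scheduler below, the optimizer instance).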

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 2000
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0.1
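
  # Schedule shape (assuming the usual cosine-with-warmup form): the lr
  # rises linearly from 0 to 3e-4 over the first 2000 steps, then decays
  # along a cosine toward final_lr_ratio * lr = 0.1 * 3e-4 = 3e-5 at
  # step 1,000,000.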