# text2semantic.yaml
defaults:
  - base
  - _self_

project: text2semantic_400m

# Lightning Trainer
trainer:
  accumulate_grad_batches: 2
  gradient_clip_val: 1.0
  gradient_clip_algorithm: 'norm'
  max_steps: 1_000_000
  precision: bf16-true

# Tokenizer Configuration
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: 01-ai/Yi-34B
  padding_side: right
  truncation_side: right

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.text.StreamTextDataset
  repo: fishaudio/cn-hubert-25hz-vq
  prefix: 'data/train'

val_dataset:
  _target_: fish_speech.datasets.text.StreamTextDataset
  repo: fishaudio/cn-hubert-25hz-vq
  prefix: 'data/test'

data:
  _target_: fish_speech.datasets.text.TextDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 32
  tokenizer: ${tokenizer}

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.TextToSemantic

  model:
    # ~ 130M parameters, for debug purpose
    _target_: fish_speech.models.text2semantic.modules.FishSpeechTransformer
    vocab_size: 64000
    codebook_size: 1032  # 1024 + 2 (bos, eos), make it divisible by 8
    num_codebooks: 1
    hidden_size: 1024
    nhead: 16
    num_encoder_layers: 12
    num_decoder_layers: 12

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1e-4
    weight_decay: 0.1
    betas: [0.9, 0.95]
    eps: 1e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 2000
      num_training_steps: ${trainer.max_steps}
      final_lr_ratio: 0.1