# text2semantic_finetune.yaml
# Fine-tuning configuration for the dual-AR text2semantic model (fish-speech).
  1. defaults:
  2. - base
  3. - model@model.model: dual_ar_2_codebook_small
  4. - _self_
  5. project: text2semantic_finetune_dual_ar
  6. max_length: 2048
  7. ckpt_path: checkpoints/text2semantic-sft-medium-v1.1-4k.pth
  8. resume_weights_only: true
  9. # Lightning Trainer
  10. trainer:
  11. accumulate_grad_batches: 1
  12. gradient_clip_val: 1.0
  13. gradient_clip_algorithm: "norm"
  14. max_steps: 1000
  15. precision: bf16-true
  16. limit_val_batches: 10
  17. val_check_interval: 100
  18. # Dataset Configuration
  19. tokenizer:
  20. _target_: transformers.AutoTokenizer.from_pretrained
  21. pretrained_model_name_or_path: fishaudio/fish-speech-1
  22. # Dataset Configuration
  23. train_dataset:
  24. _target_: fish_speech.datasets.text.AutoAugTextDataset
  25. proto_files:
  26. - data/protos
  27. tokenizer: ${tokenizer}
  28. max_length: ${max_length}
  29. num_codebooks: ${model.model.config.num_codebooks}
  30. use_speaker: 0.5
  31. interactive_prob: 0.7
  32. val_dataset:
  33. _target_: fish_speech.datasets.text.AutoAugTextDataset
  34. proto_files:
  35. - data/protos
  36. tokenizer: ${tokenizer}
  37. max_length: ${max_length}
  38. num_codebooks: ${model.model.config.num_codebooks}
  39. use_speaker: 0.5
  40. interactive_prob: 0.7
  41. data:
  42. _target_: fish_speech.datasets.text.TextDataModule
  43. train_dataset: ${train_dataset}
  44. val_dataset: ${val_dataset}
  45. num_workers: 4
  46. batch_size: 8
  47. tokenizer: ${tokenizer}
  48. max_length: ${max_length}
  49. # Model Configuration
  50. model:
  51. _target_: fish_speech.models.text2semantic.TextToSemantic
  52. model: {}
  53. optimizer:
  54. _target_: torch.optim.AdamW
  55. _partial_: true
  56. lr: 1e-5
  57. weight_decay: 0
  58. betas: [0.9, 0.95]
  59. eps: 1e-5
  60. lr_scheduler:
  61. _target_: torch.optim.lr_scheduler.LambdaLR
  62. _partial_: true
  63. lr_lambda:
  64. _target_: fish_speech.scheduler.get_cosine_schedule_with_warmup_lr_lambda
  65. _partial_: true
  66. num_warmup_steps: 0.1
  67. num_training_steps: ${trainer.max_steps}
  68. # Callbacks
  69. callbacks:
  70. model_checkpoint:
  71. every_n_train_steps: ${trainer.val_check_interval}