# text2semantic_finetune.yaml — Hydra config for fine-tuning the fish-speech
# dual-AR text-to-semantic model (Lightning trainer + datasets + model/optim).
---
defaults:
  - base
  - _self_

project: text2semantic_finetune_dual_ar
max_length: 4096
pretrained_ckpt_path: checkpoints/fish-speech-1.5

# Lightning Trainer
trainer:
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: "norm"
  max_steps: 10000
  precision: bf16-true
  limit_val_batches: 10
  val_check_interval: 100
  # strategy:
  #   find_unused_parameters: true
  #   static_graph: true

# Tokenizer Configuration
tokenizer:
  _target_: fish_speech.tokenizer.FishTokenizer
  model_path: ${pretrained_ckpt_path}/tokenizer.tiktoken

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  interactive_prob: 0.7

val_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionIterableDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  interactive_prob: 0.7

data:
  _target_: fish_speech.datasets.semantic.SemanticDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 4
  tokenizer: ${tokenizer}
  max_length: ${max_length}

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
  model:
    _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
    path: ${pretrained_ckpt_path}
    load_weights: true
    max_length: ${max_length}
    lora_config: null

  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1e-4
    weight_decay: 0
    betas: [0.9, 0.95]
    eps: 1e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 10

# Callbacks
callbacks:
  model_checkpoint:
    # Checkpoint cadence tied to validation cadence so each saved step
    # has a matching validation metric.
    every_n_train_steps: ${trainer.val_check_interval}