---
# text2semantic_finetune.yaml
# Hydra config for finetuning the fish-speech dual-AR text2semantic model.
# Composed on top of the shared `base` config; keys here override it.
defaults:
  - base
  - _self_

project: text2semantic_finetune_dual_ar
max_length: 4096
pretrained_ckpt_path: checkpoints/fish-speech-1.2-sft

# Lightning Trainer
trainer:
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  gradient_clip_algorithm: "norm"
  max_steps: 1000
  precision: bf16-true
  # Validate every 100 steps on at most 10 batches.
  limit_val_batches: 10
  val_check_interval: 100

# Tokenizer — loaded from the same pretrained checkpoint as the model.
tokenizer:
  _target_: transformers.AutoTokenizer.from_pretrained
  pretrained_model_name_or_path: ${pretrained_ckpt_path}

# Dataset Configuration
train_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  interactive_prob: 0.7

# NOTE(review): val_dataset reads the same proto_files as train_dataset —
# confirm this overlap is intentional (no held-out split is configured here).
val_dataset:
  _target_: fish_speech.datasets.semantic.AutoTextSemanticInstructionDataset
  proto_files:
    - data/protos
  tokenizer: ${tokenizer}
  causal: true
  max_length: ${max_length}
  use_speaker: false
  interactive_prob: 0.7

data:
  _target_: fish_speech.datasets.semantic.SemanticDataModule
  train_dataset: ${train_dataset}
  val_dataset: ${val_dataset}
  num_workers: 4
  batch_size: 8
  tokenizer: ${tokenizer}
  max_length: ${max_length}

# Model Configuration
model:
  _target_: fish_speech.models.text2semantic.lit_module.TextToSemantic
  model:
    _target_: fish_speech.models.text2semantic.llama.BaseTransformer.from_pretrained
    path: ${pretrained_ckpt_path}
    load_weights: true
    max_length: ${max_length}
    # null → full finetune; set a LoraConfig to train adapters instead.
    lora_config: null

  # _partial_: true — Hydra builds a partially-applied callable; the
  # LightningModule supplies the parameters/optimizer at configure time.
  optimizer:
    _target_: torch.optim.AdamW
    _partial_: true
    lr: 1e-4
    weight_decay: 0
    betas: [0.9, 0.95]
    eps: 1e-5

  lr_scheduler:
    _target_: torch.optim.lr_scheduler.LambdaLR
    _partial_: true
    lr_lambda:
      _target_: fish_speech.scheduler.get_constant_schedule_with_warmup_lr_lambda
      _partial_: true
      num_warmup_steps: 10

# Callbacks
callbacks:
  model_checkpoint:
    # Checkpoint on the same cadence as validation.
    every_n_train_steps: ${trainer.val_check_interval}