# base.yaml
  1. # Base configuration for training a model
  2. paths:
  3. run_dir: results/${project}
  4. ckpt_dir: ${paths.run_dir}/checkpoints
  5. hydra:
  6. run:
  7. dir: ${paths.run_dir}
  8. # Lightning Trainer
  9. trainer:
  10. _target_: lightning.pytorch.trainer.Trainer
  11. default_root_dir: ${paths.run_dir}
  12. accelerator: gpu
  13. num_nodes: 1
  14. devices: auto
  15. strategy:
  16. _target_: lightning.pytorch.strategies.DDPStrategy
  17. process_group_backend: nccl # This should be override when training on windows
  18. precision: bf16-mixed
  19. # disable validation by epoch end
  20. check_val_every_n_epoch: null
  21. val_check_interval: 5000
  22. max_steps: 100_000
  23. # Use torch.backends.cudnn.benchmark to speed up training
  24. benchmark: true
  25. # Callbacks
  26. callbacks:
  27. model_checkpoint:
  28. _target_: lightning.pytorch.callbacks.ModelCheckpoint
  29. dirpath: ${paths.ckpt_dir}
  30. filename: "step_{step:09d}"
  31. save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
  32. save_top_k: 5 # save 5 latest checkpoints
  33. monitor: step # use step to monitor checkpoints
  34. mode: max # save the latest checkpoint with the highest global_step
  35. every_n_epochs: null # don't save checkpoints by epoch end
  36. every_n_train_steps: 5000 # save checkpoints every 5000 steps
  37. auto_insert_metric_name: false
  38. model_summary:
  39. _target_: lightning.pytorch.callbacks.ModelSummary
  40. max_depth: 2 # the maximum depth of layer nesting that the summary will include
  41. learning_rate_monitor:
  42. _target_: lightning.pytorch.callbacks.LearningRateMonitor
  43. logging_interval: step
  44. log_momentum: false
  45. grad_norm_monitor:
  46. _target_: fish_speech.callbacks.GradNormMonitor
  47. norm_type: 2
  48. logging_interval: step
  49. # Logger
  50. logger:
  51. tensorboard:
  52. _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  53. save_dir: "${paths.run_dir}/tensorboard/"
  54. name: null
  55. log_graph: false
  56. default_hp_metric: true
  57. prefix: ""
  58. # wandb:
  59. # _target_: lightning.pytorch.loggers.wandb.WandbLogger
  60. # # name: "" # name of the run (normally generated by wandb)
  61. # save_dir: "${paths.run_dir}"
  62. # offline: False
  63. # id: null # pass correct id to resume experiment!
  64. # anonymous: null # enable anonymous logging
  65. # project: "fish-speech"
  66. # log_model: False # upload lightning ckpts
  67. # prefix: "" # a string to put at the beginning of metric keys
  68. # # entity: "" # set to name of your wandb team
  69. # group: ""
  70. # tags: ["vq", "hq", "finetune"]
  71. # job_type: ""
  72. # Loop
  73. train: true
  74. test: false