# base.yaml
  1. # Base configuration for training a model
  2. paths:
  3. run_dir: results/${project}
  4. ckpt_dir: ${paths.run_dir}/checkpoints
  5. hydra:
  6. run:
  7. dir: ${paths.run_dir}
  8. # Lightning Trainer
  9. trainer:
  10. _target_: lightning.pytorch.trainer.Trainer
  11. default_root_dir: ${paths.run_dir}
  12. accelerator: gpu
  13. num_nodes: 1
  14. devices: auto
  15. strategy:
  16. _target_: lightning.pytorch.strategies.DDPStrategy
  17. precision: bf16-mixed
  18. # disable validation by epoch end
  19. check_val_every_n_epoch: null
  20. val_check_interval: 5000
  21. max_steps: 100_000
  22. # Use torch.backends.cudnn.benchmark to speed up training
  23. benchmark: true
  24. # Callbacks
  25. callbacks:
  26. model_checkpoint:
  27. _target_: lightning.pytorch.callbacks.ModelCheckpoint
  28. dirpath: ${paths.ckpt_dir}
  29. filename: "step_{step:09d}"
  30. save_last: false # additionally always save an exact copy of the last checkpoint to a file last.ckpt
  31. save_top_k: 5 # save 5 latest checkpoints
  32. monitor: step # use step to monitor checkpoints
  33. mode: max # save the latest checkpoint with the highest global_step
  34. every_n_epochs: null # don't save checkpoints by epoch end
  35. every_n_train_steps: 5000 # save checkpoints every 5000 steps
  36. auto_insert_metric_name: false
  37. model_summary:
  38. _target_: lightning.pytorch.callbacks.ModelSummary
  39. max_depth: 2 # the maximum depth of layer nesting that the summary will include
  40. learning_rate_monitor:
  41. _target_: lightning.pytorch.callbacks.LearningRateMonitor
  42. logging_interval: step
  43. log_momentum: false
  44. grad_norm_monitor:
  45. _target_: fish_speech.callbacks.GradNormMonitor
  46. norm_type: 2
  47. logging_interval: step
  48. # Logger
  49. logger:
  50. tensorboard:
  51. _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  52. save_dir: "${paths.run_dir}/tensorboard/"
  53. name: null
  54. log_graph: false
  55. default_hp_metric: true
  56. prefix: ""
  57. # Loop
  58. train: true
  59. test: false