| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
# Base configuration for training a model.
# NOTE(review): `${project}` is not defined in this file; it is presumably
# supplied by a derived config or on the command line — confirm.
paths:
  run_dir: results/${project}
  ckpt_dir: ${paths.run_dir}/checkpoints

hydra:
  run:
    # Hydra's per-run output directory (saved config, logs) is the same
    # directory the training run writes into.
    dir: ${paths.run_dir}

# Lightning Trainer
trainer:
  _target_: lightning.pytorch.trainer.Trainer

  default_root_dir: ${paths.run_dir}
  accelerator: gpu
  num_nodes: 1
  devices: auto
  strategy:
    _target_: lightning.pytorch.strategies.DDPStrategy
    # NCCL is not supported on Windows; override this (e.g. with `gloo`) when
    # training there.
    process_group_backend: nccl

  precision: bf16-mixed

  # Disable end-of-epoch validation: with `check_val_every_n_epoch: null`, an
  # integer `val_check_interval` counts training batches across epochs.
  check_val_every_n_epoch: null
  val_check_interval: 5000
  # Written without `_` digit separators: YAML 1.1 loaders accept `100_000`,
  # but YAML 1.2 core-schema loaders would read it as a string.
  max_steps: 100000

  # Use torch.backends.cudnn.benchmark to speed up training
  benchmark: true

# Callbacks
callbacks:
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    dirpath: ${paths.ckpt_dir}
    # With auto_insert_metric_name disabled this yields e.g. step_000005000.ckpt
    filename: "step_{step:09d}"
    # When true, also write an exact copy of the latest checkpoint to last.ckpt
    save_last: false
    # Keep the 5 best checkpoints by `monitor`; ranking by step with mode=max
    # means the 5 most recent ones are kept.
    save_top_k: 5
    monitor: step  # rank checkpoints by global step
    mode: max  # a higher global step ranks higher, i.e. newer wins
    every_n_epochs: null  # don't save checkpoints at epoch end
    every_n_train_steps: 5000  # save a checkpoint every 5000 training steps
    auto_insert_metric_name: false  # don't insert "step=" into the filename

  model_summary:
    _target_: lightning.pytorch.callbacks.ModelSummary
    max_depth: 2  # maximum depth of layer nesting shown in the summary

  learning_rate_monitor:
    _target_: lightning.pytorch.callbacks.LearningRateMonitor
    logging_interval: step
    log_momentum: false

  grad_norm_monitor:
    _target_: fish_speech.callbacks.GradNormMonitor
    norm_type: 2
    logging_interval: step

  progress_bar:
    _target_: fish_speech.callbacks.GradAccumProgressBar

# Logger
logger:
  tensorboard:
    _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
    save_dir: "${paths.run_dir}/tensorboard/"
    name: null
    log_graph: false
    default_hp_metric: true
    prefix: ""

  # Optional Weights & Biases logger; uncomment to enable.
  # wandb:
  #   _target_: lightning.pytorch.loggers.wandb.WandbLogger
  #   # name: ""  # name of the run (normally generated by wandb)
  #   save_dir: "${paths.run_dir}"
  #   offline: false
  #   id: null  # set to an existing run id to resume that experiment
  #   anonymous: null  # set to true to enable anonymous logging
  #   project: "fish-speech"
  #   log_model: false  # upload Lightning checkpoints
  #   prefix: ""  # a string to put at the beginning of metric keys
  #   # entity: ""  # set to the name of your wandb team
  #   group: ""
  #   tags: ["vq", "hq", "finetune"]
  #   job_type: ""

# Loop
# NOTE(review): presumably read by the training entry point to decide whether
# to run fitting (`train`) and/or testing (`test`) — confirm against the script.
train: true
test: false
|