# train.py
  1. import torch
  2. from lightning.fabric import Fabric
  3. import hydra
  4. from omegaconf import DictConfig, OmegaConf
  5. import pyrootutils
  6. # Allow TF32 on Ampere GPUs
  7. torch.set_float32_matmul_precision("high")
  8. torch.backends.cudnn.allow_tf32 = True
  9. # register eval resolver and root
  10. pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
  11. OmegaConf.register_new_resolver("eval", eval)
  12. # flake8: noqa: E402
  13. from speech_lm.dataset import build_dataset
  14. @hydra.main(version_base="1.3", config_path="./configs", config_name="pretrain.yaml")
  15. def main(cfg: DictConfig):
  16. print(cfg)
  17. if __name__ == "__main__":
  18. main()