train_colbert.py 1.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
import os
import shutil
from pathlib import Path

import configue
import typer

from colpali_engine.trainer.colmodel_training import ColModelTraining, ColModelTrainingConfig
from colpali_engine.utils.gpu_stats import print_gpu_utilization
  7. app = typer.Typer(pretty_exceptions_enable=False)
  8. @app.command()
  9. def main(config_file: Path) -> None:
  10. """
  11. Training script for ColVision models.
  12. Args:
  13. config_file (Path): Path to the configuration file.
  14. """
  15. print_gpu_utilization()
  16. print("Loading config")
  17. config = configue.load(config_file, sub_path="config")
  18. print("Creating Setup")
  19. if isinstance(config, ColModelTrainingConfig):
  20. training_app = ColModelTraining(config)
  21. else:
  22. raise ValueError("Config must be of type ColModelTrainingConfig")
  23. if config.run_train:
  24. print("Training model")
  25. training_app.train()
  26. training_app.save()
  27. os.system(f"cp {config_file} {training_app.config.output_dir}/training_config.yml")
  28. print("Done!")
  29. if __name__ == "__main__":
  30. app()