extract_model.py 536 B

123456789101112131415161718192021
import os

import click
import torch
from loguru import logger
  4. @click.command()
  5. @click.argument("model_path")
  6. @click.argument("output_path")
  7. def main(model_path, output_path):
  8. if model_path == output_path:
  9. logger.error("Model path and output path are the same")
  10. return
  11. logger.info(f"Loading model from {model_path}")
  12. state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
  13. torch.save(state_dict, output_path)
  14. logger.info(f"Model saved to {output_path}")
  15. if __name__ == "__main__":
  16. main()