extract_model.py 727 B

1234567891011121314151617181920212223242526272829
from pathlib import Path

import click
import torch
from loguru import logger
  4. @click.command()
  5. @click.argument("model_path")
  6. @click.argument("output_path")
  7. def main(model_path, output_path):
  8. if model_path == output_path:
  9. logger.error("Model path and output path are the same")
  10. click.Abort()
  11. logger.info(f"Loading model from {model_path}")
  12. state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
  13. logger.info("Extracting model")
  14. state_dict = {
  15. state_dict: value
  16. for state_dict, value in state_dict.items()
  17. if state_dict.startswith("model.")
  18. }
  19. torch.save(state_dict, output_path)
  20. logger.info(f"Model saved to {output_path}")
if __name__ == "__main__":
    # Script entry point: invoke the click command defined above; click
    # takes the arguments from the command line.
    main()