# extract_whisper_vq_weights.py (687 B)

  1. from pathlib import Path
  2. import click
  3. import torch
  4. from loguru import logger
  5. @click.command()
  6. @click.argument(
  7. "input-file",
  8. type=click.Path(exists=True, dir_okay=False, file_okay=True, path_type=Path),
  9. )
  10. @click.argument(
  11. "output-file",
  12. type=click.Path(exists=False, dir_okay=False, file_okay=True, path_type=Path),
  13. )
  14. def extract(input_file: Path, output_file: Path):
  15. model = torch.load(input_file, map_location="cpu")["model"]
  16. state_dict = {k: v for k, v in model.items() if k.startswith("whisper") is False}
  17. torch.save(state_dict, output_file)
  18. logger.info(f"Saved {len(state_dict)} keys to {output_file}")
  19. if __name__ == "__main__":
  20. extract()