# extract_model.py
  1. import torch
  2. state_dict = torch.load(
  3. "results/text2semantic_400m/checkpoints/step_000095000.ckpt", map_location="cpu"
  4. )["state_dict"]
  5. state_dict = {
  6. state_dict.replace("model.", ""): value
  7. for state_dict, value in state_dict.items()
  8. if state_dict.startswith("model.")
  9. }
  10. torch.save(state_dict, "results/text2semantic_400m/step_000095000_weights.ckpt")