# cpu_text_encoder.py
import torch
from transformers import PreTrainedModel

from ..utils import torch_gc
  4. class CPUTextEncoderWrapper(PreTrainedModel):
  5. def __init__(self, text_encoder, torch_dtype):
  6. super().__init__(text_encoder.config)
  7. self.config = text_encoder.config
  8. self._device = text_encoder.device
  9. # cpu not support float16
  10. self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
  11. self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
  12. self.torch_dtype = torch_dtype
  13. del text_encoder
  14. torch_gc()
  15. def __call__(self, x, **kwargs):
  16. input_device = x.device
  17. original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
  18. for k, v in original_output.items():
  19. if isinstance(v, tuple):
  20. original_output[k] = [
  21. v[i].to(input_device).to(self.torch_dtype) for i in range(len(v))
  22. ]
  23. else:
  24. original_output[k] = v.to(input_device).to(self.torch_dtype)
  25. return original_output
  26. @property
  27. def dtype(self):
  28. return self.torch_dtype
  29. @property
  30. def device(self) -> torch.device:
  31. """
  32. `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
  33. device).
  34. """
  35. return self._device