import torch
from transformers import PreTrainedModel

from ..utils import torch_gc


class CPUTextEncoderWrapper(PreTrainedModel):
    def __init__(self, text_encoder, torch_dtype):
        super().__init__(text_encoder.config)
        self.config = text_encoder.config
        self._device = text_encoder.device
        # The CPU does not support float16, so keep the encoder in float32.
        self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
        self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
        self.torch_dtype = torch_dtype
        del text_encoder
        torch_gc()

    def __call__(self, x, **kwargs):
        # Run the encoder on the CPU, then cast every output tensor back to the
        # caller's device and the requested dtype.
        input_device = x.device
        original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
        for k, v in original_output.items():
            if isinstance(v, tuple):
                original_output[k] = [
                    v[i].to(input_device).to(self.torch_dtype) for i in range(len(v))
                ]
            else:
                original_output[k] = v.to(input_device).to(self.torch_dtype)
        return original_output

    @property
    def dtype(self):
        return self.torch_dtype

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the
        module parameters are on the same device).
        """
        return self._device
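
For context, a wrapper like this is typically swapped into a diffusers pipeline so that only the UNet and VAE occupy GPU memory while prompts are encoded on the CPU. The sketch below is illustrative only; the pipeline class, model id, and prompt are assumptions for demonstration, not part of this module.

import torch
from diffusers import StableDiffusionPipeline

# Illustrative sketch: the model id and prompt are placeholder assumptions.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Swap in the CPU wrapper. It still reports the pipeline's requested dtype and
# the original device, so downstream pipeline code needs no changes.
pipe.text_encoder = CPUTextEncoderWrapper(pipe.text_encoder, torch.float16)

image = pipe("a photo of an astronaut riding a horse").images[0]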