# transformers_wrappers.py
  1. import importlib
  2. if importlib.util.find_spec("transformers") is not None:
  3. from transformers import AutoProcessor, AutoTokenizer
  4. from transformers.tokenization_utils import PreTrainedTokenizer
  5. class AllPurposeWrapper:
  6. def __new__(cls, class_to_instanciate, *args, **kwargs):
  7. return class_to_instanciate.from_pretrained(*args, **kwargs)
  8. class AutoProcessorWrapper:
  9. def __new__(cls, *args, **kwargs):
  10. return AutoProcessor.from_pretrained(*args, **kwargs)
  11. class AutoTokenizerWrapper(PreTrainedTokenizer):
  12. def __new__(cls, *args, **kwargs):
  13. return AutoTokenizer.from_pretrained(*args, **kwargs)
  14. else:
  15. raise ModuleNotFoundError("Transformers must be loaded")