clip_model.py 738 B

12345678910111213141516171819202122
  1. import os
  2. import torch
  3. from transformers import AutoModel, AutoConfig, CLIPProcessor
  4. MODEL_NAME = "BAAI/EVA-CLIP-8B"
  5. DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
  6. DTYPE = torch.float8 if DEVICE == "cuda" else torch.float32
  7. MAX_BATCH = int(os.getenv("MAX_BATCH", "32"))
  8. TRUST_REMOTE_CODE = True
  9. print(f"[model_config] Loading {MODEL_NAME} on {DEVICE} dtype={DTYPE} ...")
  10. config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=TRUST_REMOTE_CODE)
  11. model = AutoModel.from_pretrained(
  12. MODEL_NAME, config=config, trust_remote_code=TRUST_REMOTE_CODE
  13. ).to(dtype=DTYPE, device=DEVICE).eval()
  14. processor = CLIPProcessor.from_pretrained(MODEL_NAME)
  15. def get_model():
  16. return model, processor, DEVICE, DTYPE, MAX_BATCH