clip_model.py

import os

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

MODEL_NAME = "BAAI/EVA-CLIP-8B-plus"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
MAX_BATCH = int(os.getenv("MAX_BATCH", "32"))

print(f"[model_config] Loading {MODEL_NAME} on {DEVICE} dtype={DTYPE} ...")

config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_NAME, config=config, trust_remote_code=True
).to(dtype=DTYPE, device=DEVICE).eval()

# EVA-CLIP checkpoints do not always ship a preprocessor config, so fall
# back to standard CLIP preprocessing (224px bicubic resize, CLIP mean/std).
try:
    from transformers import CLIPImageProcessor

    image_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
except Exception:
    print("[warning] EVA-CLIP has no preprocessing config; "
          "building an ImageProcessor with default parameters")
    from transformers import CLIPImageProcessor

    image_processor = CLIPImageProcessor(
        size={"shortest_edge": 224},
        resample=3,  # PIL.Image.BICUBIC
        crop_size={"height": 224, "width": 224},
        image_mean=[0.48145466, 0.4578275, 0.40821073],
        image_std=[0.26862954, 0.26130258, 0.27577711],
    )

# Tokenizer, in case text embeddings are needed later on.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def get_model():
    """Return the loaded model plus its pre/post-processing objects."""
    return model, image_processor, tokenizer, DEVICE, DTYPE, MAX_BATCH
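

# Minimal usage sketch, not part of the original file: "example.jpg" is a
# hypothetical local image, and encode_image() assumes the method name exposed
# by the EVA-CLIP remote code on the Hugging Face Hub, which may vary between
# model revisions.
if __name__ == "__main__":
    from PIL import Image

    model, image_processor, tokenizer, device, dtype, _ = get_model()

    # Preprocess one image and move it to the model's device/dtype.
    image = Image.open("example.jpg").convert("RGB")
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device=device, dtype=dtype)

    # Encode and L2-normalize, as is conventional for CLIP-style embeddings.
    with torch.no_grad():
        image_features = model.encode_image(pixel_values)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    print(f"[demo] image embedding shape: {tuple(image_features.shape)}")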