
add img embedding.py

luojunhui, 4 weeks ago
commit 631e8e24f4
1 changed file with 24 additions and 4 deletions
applications/clip_embedding/clip_model.py  (+24 -4)

@@ -1,6 +1,6 @@
 import os
 import torch
-from transformers import AutoModel, AutoConfig, CLIPProcessor
+from transformers import AutoModel, AutoConfig, AutoProcessor, CLIPImageProcessor, AutoTokenizer
 
 MODEL_NAME = "BAAI/EVA-CLIP-8B"
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -11,12 +11,32 @@ TRUST_REMOTE_CODE = True
 
 print(f"[model_config] Loading {MODEL_NAME} on {DEVICE} dtype={DTYPE} ...")
 
-config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_NAME, trust_remote_code=TRUST_REMOTE_CODE)
+# Load the model configuration
+config = AutoConfig.from_pretrained(
+    pretrained_model_name_or_path=MODEL_NAME,
+    trust_remote_code=TRUST_REMOTE_CODE
+)
+
+# Load the model
 model = AutoModel.from_pretrained(
-    pretrained_model_name_or_path=MODEL_NAME, config=config, trust_remote_code=TRUST_REMOTE_CODE
+    pretrained_model_name_or_path=MODEL_NAME,
+    config=config,
+    trust_remote_code=TRUST_REMOTE_CODE
 ).to(dtype=DTYPE, device=DEVICE).eval()
 
-processor = CLIPProcessor.from_pretrained(pretrained_model_name_or_path=MODEL_NAME, trust_remote_code=TRUST_REMOTE_CODE)
+# Try AutoProcessor first (covers the special case of EVA-CLIP)
+try:
+    processor = AutoProcessor.from_pretrained(
+        pretrained_model_name_or_path=MODEL_NAME,
+        trust_remote_code=TRUST_REMOTE_CODE
+    )
+except Exception as e:
+    print(f"[warning] AutoProcessor 加载失败: {e}")
+    print("[info] 尝试手动组合 ImageProcessor + Tokenizer ...")
+    processor = {
+        "image_processor": CLIPImageProcessor.from_pretrained(MODEL_NAME),
+        "tokenizer": AutoTokenizer.from_pretrained(MODEL_NAME),
+    }
 
 def get_model():
     return model, processor, DEVICE, DTYPE, MAX_BATCH
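
After this change, the `processor` returned by `get_model()` can be either a single AutoProcessor or the fallback dict of `image_processor` + `tokenizer`, so downstream callers need to handle both shapes. Below is a minimal usage sketch (not part of this commit) showing one way to do that; the import path, the example inputs, and the `encode_image` / `encode_text` method names are assumptions about the EVA-CLIP remote code, not something verified by the diff.

# Usage sketch (assumption-heavy): consume get_model() and normalize both
# processor shapes before computing embeddings. encode_image / encode_text
# are assumed CLIP-style methods exposed by the EVA-CLIP remote code.
import torch
from PIL import Image

from applications.clip_embedding.clip_model import get_model  # assumed importable as a package

model, processor, device, dtype, max_batch = get_model()

def preprocess(image, text):
    """Return (pixel_values, input_ids) regardless of which processor shape was loaded."""
    if isinstance(processor, dict):
        # Fallback path: separate image processor and tokenizer.
        pixel_values = processor["image_processor"](images=image, return_tensors="pt").pixel_values
        input_ids = processor["tokenizer"]([text], return_tensors="pt", padding=True).input_ids
    else:
        # AutoProcessor path: one combined processor handles both modalities.
        batch = processor(images=image, text=[text], return_tensors="pt", padding=True)
        pixel_values, input_ids = batch.pixel_values, batch.input_ids
    return pixel_values.to(device=device, dtype=dtype), input_ids.to(device)

with torch.no_grad():
    pixel_values, input_ids = preprocess(Image.open("example.jpg"), "a photo of a cat")  # hypothetical inputs
    image_emb = model.encode_image(pixel_values)  # assumed CLIP-style API
    text_emb = model.encode_text(input_ids)       # assumed CLIP-style API
    # L2-normalize so the dot product is a cosine similarity.
    image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)
    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
    print(f"cosine similarity: {(image_emb @ text_emb.T).item():.4f}")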