|
@@ -1,6 +1,6 @@
|
|
import os
|
|
import os
|
|
import torch
|
|
import torch
|
|
-from transformers import AutoModel, AutoConfig, CLIPProcessor
|
|
|
|
|
|
+from transformers import AutoModel, AutoConfig, AutoProcessor, CLIPImageProcessor, AutoTokenizer
|
|
|
|
|
|
MODEL_NAME = "BAAI/EVA-CLIP-8B"
|
|
MODEL_NAME = "BAAI/EVA-CLIP-8B"
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -11,12 +11,32 @@ TRUST_REMOTE_CODE = True
|
|
|
|
|
|
print(f"[model_config] Loading {MODEL_NAME} on {DEVICE} dtype={DTYPE} ...")
|
|
print(f"[model_config] Loading {MODEL_NAME} on {DEVICE} dtype={DTYPE} ...")
|
|
|
|
|
|
-config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_NAME, trust_remote_code=TRUST_REMOTE_CODE)
|
|
|
|
|
|
+# 加载模型配置
|
|
|
|
+config = AutoConfig.from_pretrained(
|
|
|
|
+ pretrained_model_name_or_path=MODEL_NAME,
|
|
|
|
+ trust_remote_code=TRUST_REMOTE_CODE
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+# 加载模型
|
|
model = AutoModel.from_pretrained(
|
|
model = AutoModel.from_pretrained(
|
|
- pretrained_model_name_or_path=MODEL_NAME, config=config, trust_remote_code=TRUST_REMOTE_CODE
|
|
|
|
|
|
+ pretrained_model_name_or_path=MODEL_NAME,
|
|
|
|
+ config=config,
|
|
|
|
+ trust_remote_code=TRUST_REMOTE_CODE
|
|
).to(dtype=DTYPE, device=DEVICE).eval()
|
|
).to(dtype=DTYPE, device=DEVICE).eval()
|
|
|
|
|
|
-processor = CLIPProcessor.from_pretrained(pretrained_model_name_or_path=MODEL_NAME, trust_remote_code=TRUST_REMOTE_CODE)
|
|
|
|
|
|
+# 优先尝试 AutoProcessor(适配EVA-CLIP这种特殊情况)
|
|
|
|
+try:
|
|
|
|
+ processor = AutoProcessor.from_pretrained(
|
|
|
|
+ pretrained_model_name_or_path=MODEL_NAME,
|
|
|
|
+ trust_remote_code=TRUST_REMOTE_CODE
|
|
|
|
+ )
|
|
|
|
+except Exception as e:
|
|
|
|
+ print(f"[warning] AutoProcessor 加载失败: {e}")
|
|
|
|
+ print("[info] 尝试手动组合 ImageProcessor + Tokenizer ...")
|
|
|
|
+ processor = {
|
|
|
|
+ "image_processor": CLIPImageProcessor.from_pretrained(MODEL_NAME),
|
|
|
|
+ "tokenizer": AutoTokenizer.from_pretrained(MODEL_NAME),
|
|
|
|
+ }
|
|
|
|
|
|
def get_model():
|
|
def get_model():
|
|
return model, processor, DEVICE, DTYPE, MAX_BATCH
|
|
return model, processor, DEVICE, DTYPE, MAX_BATCH
|