|
@@ -28,13 +28,13 @@ async def embed_image_url(img_url_list: List[str]):
|
|
|
for chunk_start in range(0, len(images), MAX_BATCH):
|
|
|
chunk = images[chunk_start:chunk_start + MAX_BATCH]
|
|
|
|
|
|
- # ✅ 用 image_processor,不再用混合 processor
|
|
|
+ # 只取 pixel_values,不传 text
|
|
|
inputs = image_processor(images=chunk, return_tensors="pt")
|
|
|
- inputs = {k: v.to(DEVICE, dtype=DTYPE) if hasattr(v, "to") else v for k, v in inputs.items()}
|
|
|
+ pixel_values = inputs["pixel_values"].to(DEVICE, dtype=DTYPE)
|
|
|
|
|
|
- outputs = model(**inputs)
|
|
|
- # 某些实现是 outputs.last_hidden_state,某些是 outputs.image_embeds
|
|
|
- feats = outputs.image_embeds if hasattr(outputs, "image_embeds") else outputs.last_hidden_state
|
|
|
+ with torch.no_grad():
|
|
|
+ # ✅ 调用图像编码器
|
|
|
+ feats = model.get_image_features(pixel_values=pixel_values)
|
|
|
|
|
|
feats = _normalize(feats)
|
|
|
outputs.extend(_to_list(feats))
|