瀏覽代碼

add img embedding.py

luojunhui 4 周之前
父節點
當前提交
634e7f7e98
共有 1 個文件被更改,包括 5 次插入5 次删除
  1. 5 5
      applications/clip_embedding/embedding.py

+ 5 - 5
applications/clip_embedding/embedding.py

@@ -28,13 +28,13 @@ async def embed_image_url(img_url_list: List[str]):
     for chunk_start in range(0, len(images), MAX_BATCH):
         chunk = images[chunk_start:chunk_start + MAX_BATCH]
 
-        # ✅ 用 image_processor,不再用混合 processor
+        # 只取 pixel_values,不传 text
         inputs = image_processor(images=chunk, return_tensors="pt")
-        inputs = {k: v.to(DEVICE, dtype=DTYPE) if hasattr(v, "to") else v for k, v in inputs.items()}
+        pixel_values = inputs["pixel_values"].to(DEVICE, dtype=DTYPE)
 
-        outputs = model(**inputs)
-        # 某些实现是 outputs.last_hidden_state,某些是 outputs.image_embeds
-        feats = outputs.image_embeds if hasattr(outputs, "image_embeds") else outputs.last_hidden_state
+        with torch.no_grad():
+            # ✅ 调用图像编码器
+            feats = model.get_image_features(pixel_values=pixel_values)
 
         feats = _normalize(feats)
         outputs.extend(_to_list(feats))