Explorar o código

add img embedding.py

luojunhui hai 4 semanas
pai
achega
ab46a698ce
Modificáronse 2 ficheiros con 9 adicións e 2 borrados
  1. 8 2
      applications/clip_embedding/embedding.py
  2. 1 0
      routes/buleprint.py

+ 8 - 2
applications/clip_embedding/embedding.py

@@ -27,10 +27,16 @@ async def embed_image_url(img_url_list: List[str]):
     outputs = []
     for chunk_start in range(0, len(images), MAX_BATCH):
         chunk = images[chunk_start:chunk_start + MAX_BATCH]
-        inputs = processor(images=chunk, return_tensors="pt")
+
+        # 兼容两种情况:AutoProcessor vs dict(fallback)
+        if isinstance(processor, dict):
+            inputs = processor["image_processor"](images=chunk, return_tensors="pt")
+        else:
+            inputs = processor(images=chunk, return_tensors="pt")
+
         inputs = {k: v.to(DEVICE, dtype=DTYPE) if hasattr(v, "to") else v for k, v in inputs.items()}
         feats = model.get_image_features(**inputs)
         feats = _normalize(feats)
         outputs.extend(_to_list(feats))
 
-    return outputs
+    return outputs

+ 1 - 0
routes/buleprint.py

@@ -27,6 +27,7 @@ def server_routes(vector_db):
         url_list = body.get("url_list", [])
         if not url_list:
             return jsonify({"error": "error  url_list"})
+
         embeddings = await embed_image_url(url_list)
         return jsonify({"embeddings": embeddings, "dim": len(embeddings[0])})