
add img embedding.py

luojunhui 4 weeks ago
parent
commit
22192ccfbd

+ 5 - 0
applications/clip_embedding/__init__.py

@@ -0,0 +1,5 @@
+from .embedding import embed_image_url
+
+__all__ = [
+    "embed_image_url"
+]

+ 22 - 0
applications/clip_embedding/clip_model.py

@@ -0,0 +1,22 @@
+import os
+import torch
+from transformers import AutoModel, AutoConfig, CLIPProcessor
+
+MODEL_NAME = "BAAI/EVA-CLIP-8B"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
+MAX_BATCH = int(os.getenv("MAX_BATCH", "32"))
+
+TRUST_REMOTE_CODE = True
+
+print(f"[model_config] Loading {MODEL_NAME} on {DEVICE} dtype={DTYPE} ...")
+
+config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=TRUST_REMOTE_CODE)
+model = AutoModel.from_pretrained(
+    MODEL_NAME, config=config, trust_remote_code=TRUST_REMOTE_CODE
+).to(dtype=DTYPE, device=DEVICE).eval()
+
+processor = CLIPProcessor.from_pretrained(MODEL_NAME)
+
+def get_model():
+    return model, processor, DEVICE, DTYPE, MAX_BATCH
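
For reference, a minimal smoke test of this loader might look like the sketch below (hypothetical, not part of the commit; "sample.jpg" is a placeholder path, and loading BAAI/EVA-CLIP-8B assumes the weights are reachable and there is enough GPU memory for fp16):

    import torch
    from PIL import Image
    from applications.clip_embedding.clip_model import get_model

    model, processor, device, dtype, _max_batch = get_model()
    img = Image.open("sample.jpg").convert("RGB")  # placeholder image
    inputs = processor(images=[img], return_tensors="pt")
    with torch.no_grad():
        feats = model.get_image_features(
            pixel_values=inputs["pixel_values"].to(device, dtype=dtype)
        )
    print(feats.shape)  # (1, embedding_dim)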

+ 39 - 0
applications/clip_embedding/embedding.py

@@ -0,0 +1,39 @@
+import io
+import urllib.request
+from typing import List
+
+import torch
+from PIL import Image
+
+from .clip_model import get_model
+
+# init model
+model, processor, DEVICE, DTYPE, MAX_BATCH = get_model()
+
+
+def _normalize(x: torch.Tensor) -> torch.Tensor:
+    return x / (x.norm(dim=-1, keepdim=True) + 1e-12)
+
+def _to_list(x: torch.Tensor):
+    return x.detach().cpu().tolist()
+
+async def embed_image_url(img_url_list: List[str]):
+    # NOTE: urllib is blocking, so downloads run synchronously inside this coroutine
+    images = []
+    for u in img_url_list:
+        with urllib.request.urlopen(u, timeout=15) as r:
+            img = Image.open(io.BytesIO(r.read())).convert("RGB")
+            images.append(img)
+
+    outputs = []
+    for chunk_start in range(0, len(images), MAX_BATCH):
+        chunk = images[chunk_start:chunk_start + MAX_BATCH]
+        inputs = processor(images=chunk, return_tensors="pt")
+        inputs = {k: v.to(DEVICE, dtype=DTYPE) if hasattr(v, "to") else v for k, v in inputs.items()}
+        # run inference without autograd to avoid tracking gradients and wasting memory
+        with torch.no_grad():
+            feats = model.get_image_features(**inputs)
+            feats = _normalize(feats)
+        outputs.extend(_to_list(feats))
+
+    return outputs
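
Since embed_image_url is a coroutine, it can be exercised outside the server with asyncio; a sketch, assuming the URL below is replaced with a reachable image:

    import asyncio
    from applications.clip_embedding import embed_image_url

    # placeholder URL -- substitute a real, reachable image
    vecs = asyncio.run(embed_image_url(["https://example.com/cat.jpg"]))
    print(len(vecs), len(vecs[0]))  # image count, embedding dimension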

+ 10 - 0
routes/buleprint.py

@@ -2,6 +2,7 @@ from quart import Blueprint, jsonify, request
 
 from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG
 from applications.api import get_basic_embedding
+from applications.clip_embedding import embed_image_url
 
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
@@ -20,6 +21,15 @@ def server_routes(vector_db):
         embedding = await get_basic_embedding(text, model_name)
         return jsonify({"embedding": embedding})
 
+    @server_bp.route("/embed_image", methods=["POST"])
+    async def embed_image():
+        body = await request.get_json()
+        url_list = body.get("url_list", [])
+        if not url_list:
+            return jsonify({"error": "url_list is empty"}), 400
+        embeddings = await embed_image_url(url_list)
+        return jsonify({"embeddings": embeddings, "dim": len(embeddings[0])})
+
     @server_bp.route("/search", methods=["POST"])
     async def search():
         pass
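
End to end, the new endpoint can be called like this (a sketch assuming the Quart app listens on localhost:8000; host and port are placeholders, while the /api/embed_image path and the url_list / embeddings / dim keys come from the code above):

    import requests

    resp = requests.post(
        "http://localhost:8000/api/embed_image",  # placeholder host/port
        json={"url_list": ["https://example.com/cat.jpg"]},
    )
    data = resp.json()
    print(data["dim"], len(data["embeddings"]))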