ソースを参照

Merge branch 'feature/luojunhui/2025-09-16-img-embedding' of Server/llm_vector_server into master

luojunhui 3 週間 前
コミット
05f31068ce
3 ファイル変更27 行追加2 行削除
  1. 3 1
      applications/api/__init__.py
  2. 13 1
      applications/api/embedding.py
  3. 11 0
      routes/buleprint.py

+ 3 - 1
applications/api/__init__.py

@@ -1,4 +1,6 @@
 from .deepseek import fetch_deepseek_completion
 from .embedding import get_basic_embedding
+from .embedding import get_img_embedding
 
-__all__ = ["get_basic_embedding", "fetch_deepseek_completion"]
+
+__all__ = ["get_basic_embedding", "get_img_embedding", "fetch_deepseek_completion"]

+ 13 - 1
applications/api/embedding.py

@@ -1,3 +1,4 @@
+from typing import List
 from applications.config import LOCAL_MODEL_CONFIG, VLLM_SERVER_URL, DEV_VLLM_SERVER_URL
 from applications.utils.http import AsyncHttpClient
 
@@ -20,4 +21,15 @@ async def get_basic_embedding(text: str, model: str, dev=False):
         return response["data"][0]["embedding"]
 
 
-__all__ = ["get_basic_embedding"]
+async def get_img_embedding(url_list: List[str], dev=False):
+    url = "http://117.50.199.192:8011/api/embed_image"
+    async with AsyncHttpClient(timeout=20) as client:
+        response = await client.post(
+            url=url,
+            json={"url_list": url_list},
+            headers={"Content-Type": "application/json"},
+        )
+        return response
+
+
+__all__ = ["get_basic_embedding", "get_img_embedding"]

+ 11 - 0
routes/buleprint.py

@@ -10,6 +10,7 @@ from applications.config import (
     WEIGHT_MAP,
 )
 from applications.api import get_basic_embedding
+from applications.api import get_img_embedding
 from applications.async_task import ChunkEmbeddingTask
 from applications.utils.milvus import MilvusSearch
 
@@ -29,6 +30,16 @@ def server_routes(mysql_db, vector_db):
         embedding = await get_basic_embedding(text, model_name)
         return jsonify({"embedding": embedding})
 
+    @server_bp.route("/img_embed", methods=["POST"])
+    async def img_embed():
+        body = await request.get_json()
+        url_list = body.get("url_list")
+        if not url_list:
+            return jsonify({"error": "error  url_list"})
+
+        embedding = await get_img_embedding(url_list)
+        return jsonify({"embedding": embedding})
+
     @server_bp.route("/chunk", methods=["POST"])
     async def chunk():
         body = await request.get_json()