luojunhui преди 1 месец
родител
ревизия
7a651689ed
променени са 7 файла, в които са добавени 80 реда и са изтрити 16 реда
  1. 3 2
      applications/config/__init__.py
  2. 11 5
      applications/config/model_config.py
  3. 3 2
      applications/embedding/__init__.py
  4. 18 3
      applications/embedding/basic.py
  5. 6 0
      config.toml
  6. 17 2
      routes/buleprint.py
  7. 22 2
      vector_app.py

+ 3 - 2
applications/config/__init__.py

@@ -1,6 +1,7 @@
-from .model_config import MODEL_CONFIG, DEFAULT_MODEL
+from .model_config import MODEL_CONFIG, DEFAULT_MODEL, LOCAL_MODEL_CONFIG
 
 __all__ = [
     "MODEL_CONFIG",
-    "DEFAULT_MODEL"
+    "DEFAULT_MODEL",
+    "LOCAL_MODEL_CONFIG"
 ]

+ 11 - 5
applications/config/model_config.py

@@ -1,10 +1,16 @@
 MODEL_CONFIG = {
-    "Qwen/Qwen3-Embedding-0.6B": {"url": "http://vllm-0.6b:8000/v1/embeddings", "dim": 1536},
-    "Qwen/Qwen3-Embedding-4B": {"url": "http://vllm-4b:8000/v1/embeddings", "dim": 1536},
-    "Qwen/Qwen3-Embedding-8B": {"url": "http://vllm-8b:8000/v1/embeddings", "dim": 1536},
+    "Qwen3-Embedding-0.6B": {"url": "http://vllm-0.6b:8000/v1/embeddings", "dim": 1536},
+    "Qwen3-Embedding-4B": {"url": "http://vllm-4b:8000/v1/embeddings", "dim": 1536},
+    "Qwen3-Embedding-8B": {"url": "http://vllm-8b:8000/v1/embeddings", "dim": 1536},
 }
 
-DEFAULT_MODEL = "Qwen/Qwen3-Embedding-0.6B"
+LOCAL_MODEL_CONFIG = {
+    "Qwen3-Embedding-0.6B": "models/Qwen3-Embedding-0.6B",
+    "Qwen3-Embedding-4B": "models/Qwen3-Embedding-4B",
+    "Qwen3-Embedding-8B": "models/Qwen3-Embedding-8B",
+}
+
+DEFAULT_MODEL = "Qwen3-Embedding-0.6B"
 
 
-__all__ = ["MODEL_CONFIG", "DEFAULT_MODEL"]
+__all__ = ["MODEL_CONFIG", "DEFAULT_MODEL", "LOCAL_MODEL_CONFIG"]

+ 3 - 2
applications/embedding/__init__.py

@@ -1,5 +1,6 @@
-from .basic import get_basic_embedding
+from .basic import get_basic_embedding, get_local_embedding
 
 __all__ = [
-    "get_basic_embedding"
+    "get_basic_embedding",
+    "get_local_embedding"
 ]

+ 18 - 3
applications/embedding/basic.py

@@ -1,4 +1,4 @@
-from applications.config import MODEL_CONFIG
+from applications.config import MODEL_CONFIG, LOCAL_MODEL_CONFIG
 from applications.utils import AsyncHttpClient
 
 
@@ -7,7 +7,7 @@ async def get_basic_embedding(text: str, model: str):
     embedding text into vectors
     :param text:
     :param model:
-    :return:
+    :return:tong
     """
     cfg = MODEL_CONFIG[model]
     async with AsyncHttpClient(timeout=20) as client:
@@ -18,6 +18,21 @@ async def get_basic_embedding(text: str, model: str):
         )
         return response['data'][0]["embedding"]
 
+
+async def get_local_embedding(text, model):
+    """
+    embedding text into vectors
+    :param text:
+    :param model:
+    :return:
+    """
+    outputs = model.get_embedding([text])
+    embedding = outputs[0]
+    return embedding
+
+
+
 __all__ = [
-    "get_basic_embedding"
+    "get_basic_embedding",
+    "get_local_embedding"
 ]

+ 6 - 0
config.toml

@@ -0,0 +1,6 @@
+reload = true
+bind = "0.0.0.0:8080"
+workers = 3
+keep_alive_timeout = 120  # 保持连接的最大秒数,根据需要调整
+graceful_timeout = 30    # 重启或停止之前等待当前工作完成的时间
+loglevel = "debug"  # 日志级别

+ 17 - 2
routes/buleprint.py

@@ -1,12 +1,12 @@
 from quart import Blueprint, jsonify, request
 
 from applications.config import DEFAULT_MODEL, MODEL_CONFIG
-from applications.embedding import get_basic_embedding
+from applications.embedding import get_basic_embedding, get_local_embedding
 
 
 server_bp = Blueprint('api', __name__, url_prefix='/api')
 
-def server_routes(vector_db):
+def server_routes(llm, vector_db):
 
     @server_bp.route('/embed', methods=['POST'])
     async def embed():
@@ -23,6 +23,21 @@ def server_routes(vector_db):
             "embedding": embedding
         })
 
+    @server_bp.route('/embed_v1', methods=['POST'])
+    async def embed_v1():
+        body = await request.get_json()
+        text = body.get('text')
+        model_name = body.get('model', DEFAULT_MODEL)
+        if not MODEL_CONFIG.get(model_name):
+            return jsonify(
+                {"error": "error  model"}
+            )
+
+        embedding = await get_local_embedding(text, llm)
+        return jsonify({
+            "embedding": embedding
+        })
+
 
     @server_bp.route('/search', methods=['POST'])
     async def search():

+ 22 - 2
vector_app.py

@@ -1,14 +1,34 @@
 from quart import Quart
 from quart_cors import cors
 from pymilvus import connections
+from vllm import LLM, SamplingParams
 
+from applications.config import LOCAL_MODEL_CONFIG, DEFAULT_MODEL
 from routes import server_routes
 
 app = Quart(__name__)
 
-# 连接图数据库
+# llm
+llm = None
+
+# 连接向量数据库
 connections.connect("default", host="milvus", port="19530")
 
+
 # 注册路由
-app_route = server_routes(connections)
+app_route = server_routes(llm, connections)
 app.register_blueprint(app_route)
+
+@app.before_serving
+async def load_model():
+    """在服务启动前加载模型"""
+    global llm
+    MODEL_PATH = LOCAL_MODEL_CONFIG[DEFAULT_MODEL]
+    if llm is None:
+        llm = LLM(
+            model=MODEL_PATH,
+            dtype="float16",   # 节省显存
+            trust_remote_code=True
+        )
+        print(f"{MODEL_PATH} 模型加载完成!")
+