|
@@ -1,5 +1,6 @@
|
|
|
import traceback
|
|
|
import uuid
|
|
|
+from typing import Dict, Any
|
|
|
|
|
|
from quart import Blueprint, jsonify, request
|
|
|
|
|
@@ -7,122 +8,120 @@ from applications.config import (
|
|
|
DEFAULT_MODEL,
|
|
|
LOCAL_MODEL_CONFIG,
|
|
|
ChunkerConfig,
|
|
|
- WEIGHT_MAP,
|
|
|
+ BASE_MILVUS_SEARCH_PARAMS,
|
|
|
)
|
|
|
+from applications.resource import get_resource_manager
|
|
|
from applications.api import get_basic_embedding
|
|
|
from applications.api import get_img_embedding
|
|
|
from applications.async_task import ChunkEmbeddingTask
|
|
|
-from applications.utils.milvus import MilvusSearch
|
|
|
+from applications.search import HybridSearch
|
|
|
+
|
|
|
|
|
|
server_bp = Blueprint("api", __name__, url_prefix="/api")
|
|
|
|
|
|
|
|
|
-def server_routes(mysql_db, vector_db):
|
|
|
-
|
|
|
- @server_bp.route("/embed", methods=["POST"])
|
|
|
- async def embed():
|
|
|
- body = await request.get_json()
|
|
|
- text = body.get("text")
|
|
|
- model_name = body.get("model", DEFAULT_MODEL)
|
|
|
- if not LOCAL_MODEL_CONFIG.get(model_name):
|
|
|
- return jsonify({"error": "error model"})
|
|
|
-
|
|
|
- embedding = await get_basic_embedding(text, model_name)
|
|
|
- return jsonify({"embedding": embedding})
|
|
|
-
|
|
|
- @server_bp.route("/img_embed", methods=["POST"])
|
|
|
- async def img_embed():
|
|
|
- body = await request.get_json()
|
|
|
- url_list = body.get("url_list")
|
|
|
- if not url_list:
|
|
|
- return jsonify({"error": "error url_list"})
|
|
|
-
|
|
|
- embedding = await get_img_embedding(url_list)
|
|
|
- return jsonify(embedding)
|
|
|
-
|
|
|
- @server_bp.route("/chunk", methods=["POST"])
|
|
|
- async def chunk():
|
|
|
- body = await request.get_json()
|
|
|
- text = body.get("text", "")
|
|
|
- text = text.strip()
|
|
|
- if not text:
|
|
|
- return jsonify({"error": "error text"})
|
|
|
- doc_id = f"doc-{uuid.uuid4()}"
|
|
|
- chunk_task = ChunkEmbeddingTask(
|
|
|
- mysql_db, vector_db, cfg=ChunkerConfig(), doc_id=doc_id
|
|
|
- )
|
|
|
- doc_id = await chunk_task.deal(body)
|
|
|
- return jsonify({"doc_id": doc_id})
|
|
|
-
|
|
|
- @server_bp.route("/search", methods=["POST"])
|
|
|
- async def search():
|
|
|
- body = await request.get_json()
|
|
|
- search_type = body.get("search_type")
|
|
|
- if not search_type:
|
|
|
- return jsonify({"error": "missing search_type"}), 400
|
|
|
-
|
|
|
- searcher = MilvusSearch(vector_db)
|
|
|
-
|
|
|
- try:
|
|
|
- # 统一参数
|
|
|
- expr = body.get("expr")
|
|
|
- search_params = body.get("search_params") or {
|
|
|
- "metric_type": "COSINE",
|
|
|
- "params": {"ef": 64},
|
|
|
- }
|
|
|
- limit = body.get("limit", 50)
|
|
|
- query = body.get("query")
|
|
|
-
|
|
|
- async def by_vector():
|
|
|
- if not query:
|
|
|
- return {"error": "missing query"}
|
|
|
- field = body.get("field", "vector_text")
|
|
|
- query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
|
|
|
- return await searcher.vector_search(
|
|
|
- query_vec=query_vec,
|
|
|
- anns_field=field,
|
|
|
- expr=expr,
|
|
|
+@server_bp.route("/embed", methods=["POST"])
+async def embed():
+    """Embed the request's ``text`` with one of the locally configured models.
+
+    Reads a JSON object with ``text`` and an optional ``model`` (defaults to
+    ``DEFAULT_MODEL``). The model name must map to a truthy entry in
+    ``LOCAL_MODEL_CONFIG``.
+
+    :return: ``{"embedding": ...}`` on success, otherwise ``{"error": ...}``
+        (HTTP 400 when the body is not a JSON object).
+    """
+    body = await request.get_json()
+    # get_json() may give None (non-JSON request) or a list/scalar; reject
+    # those explicitly instead of failing on ``.get`` below.
+    if not isinstance(body, dict):
+        return jsonify({"error": "error body"}), 400
+
+    text = body.get("text")
+    model_name = body.get("model", DEFAULT_MODEL)
+    if not LOCAL_MODEL_CONFIG.get(model_name):
+        return jsonify({"error": "error model"})
+
+    embedding = await get_basic_embedding(text, model_name)
+    return jsonify({"embedding": embedding})
|
|
|
+
|
|
|
+
|
|
|
+@server_bp.route("/img_embed", methods=["POST"])
+async def img_embed():
+    """Return image embeddings for the URLs listed in the request.
+
+    Reads a JSON object whose ``url_list`` field is passed unchanged to
+    ``get_img_embedding``.
+
+    :return: the value produced by ``get_img_embedding`` as JSON, otherwise
+        ``{"error": ...}`` (HTTP 400 when the body is not a JSON object).
+    """
+    body = await request.get_json()
+    # get_json() may give None (non-JSON request) or a list/scalar; reject
+    # those explicitly instead of failing on ``.get`` below.
+    if not isinstance(body, dict):
+        return jsonify({"error": "error body"}), 400
+
+    url_list = body.get("url_list")
+    if not url_list:
+        return jsonify({"error": "error url_list"})
+
+    embedding = await get_img_embedding(url_list)
+    return jsonify(embedding)
|
|
|
+
|
|
|
+
|
|
|
+@server_bp.route("/chunk", methods=["POST"])
+async def chunk():
+    """Start a chunk-and-embed task for the submitted document.
+
+    Reads a JSON object that must contain a non-blank string ``text``. A new
+    ``doc-<uuid4>`` id is generated, a ``ChunkEmbeddingTask`` is built with the
+    MySQL, Milvus and ES clients from the resource manager, and the whole
+    request body is handed to ``ChunkEmbeddingTask.deal``.
+
+    :return: ``{"doc_id": ...}`` with the id returned by ``deal``, otherwise
+        ``{"error": ...}`` (HTTP 400 when the body is not a JSON object).
+    """
+    body = await request.get_json()
+    # get_json() may give None (non-JSON request) or a list/scalar; reject
+    # those explicitly instead of failing on ``.get`` below.
+    if not isinstance(body, dict):
+        return jsonify({"error": "error body"}), 400
+
+    # ``text`` is only validated here; the task receives the full body.
+    # The isinstance check keeps a non-string value from crashing ``.strip()``.
+    text = body.get("text", "")
+    if not isinstance(text, str) or not text.strip():
+        return jsonify({"error": "error text"})
+
+    resource = get_resource_manager()
+    doc_id = f"doc-{uuid.uuid4()}"
+    chunk_task = ChunkEmbeddingTask(
+        resource.mysql_client,
+        resource.milvus_client,
+        cfg=ChunkerConfig(),
+        doc_id=doc_id,
+        es_pool=resource.es_client,
+    )
+    doc_id = await chunk_task.deal(body)
+    return jsonify({"doc_id": doc_id})
|
|
|
+
|
|
|
+
|
|
|
+@server_bp.route("/search", methods=["POST"])
+async def search():
+    """Run a vector search for ``query_text`` and return the result as JSON.
+
+    Fields read from the JSON request body:
+
+    - ``search_type``: ``"base"`` or ``"hybrid"``; ``"strategy"`` is recognised
+      but not implemented.
+    - ``query_text``: text to embed with ``DEFAULT_MODEL``; required.
+    - ``filters``: condition filters, hybrid only (default ``{}``).
+    - ``anns_field``: vector field to search (default ``"vector_text"``).
+    - ``search_params``: vector search parameters
+      (default ``BASE_MILVUS_SEARCH_PARAMS``).
+    - ``es_size``: size of the first-stage ES filter, hybrid only
+      (default 10000).
+    - ``sort_by``: sort key, hybrid only (default ``None``).
+    - ``milvus_size``: number of coarse-ranked Milvus results, hybrid only
+      (default 20). The legacy key ``milvus`` is still accepted.
+    - ``limit``: result limit, base only (default 10).
+
+    ``_source`` (whether to return metadata) and ``query_text`` itself are not
+    forwarded to the search engine.
+    NOTE(review): confirm whether ``hybrid_search`` should receive them.
+
+    :return: the search engine response with HTTP 200; ``{"error": ...}`` with
+        400 for a bad body or unknown ``search_type``, 501 for ``"strategy"``,
+        500 when embedding or searching raises. A missing ``query_text`` keeps
+        the existing behaviour of an error payload with the default status.
+    """
+    body = await request.get_json()
+    # get_json() may give None (non-JSON request) or a list/scalar; reject
+    # those explicitly instead of failing on ``.get`` below.
+    if not isinstance(body, dict):
+        return jsonify({"error": "error body"}), 400
+
+    # Parse request fields. ``or`` defaults also cover an explicit JSON null.
+    search_type: str = body.get("search_type")
+    filters: Dict[str, Any] = body.get("filters") or {}
+    anns_field: str = body.get("anns_field", "vector_text")
+    search_params: Dict[str, Any] = (
+        body.get("search_params") or BASE_MILVUS_SEARCH_PARAMS
+    )
+    query_text: str = body.get("query_text")
+    es_size: int = body.get("es_size", 10000)
+    sort_by: str = body.get("sort_by")
+    # Documented name is ``milvus_size``; ``milvus`` is kept as a fallback so
+    # existing clients keep working.
+    milvus_size: int = body.get("milvus_size", body.get("milvus", 20))
+    limit: int = body.get("limit", 10)
+
+    if not query_text:
+        return jsonify({"error": "error query_text"})
+    # Reject unsupported search types before paying for an embedding call.
+    if search_type == "strategy":
+        return jsonify({"error": "strategy not implemented"}), 501
+    if search_type not in ("base", "hybrid"):
+        return jsonify({"error": "error search_type"}), 400
+
+    resource = get_resource_manager()
+    search_engine = HybridSearch(
+        milvus_pool=resource.milvus_client, es_pool=resource.es_client
+    )
+    try:
+        # Embedding happens inside the try so its failures get the same JSON
+        # error response as search failures.
+        query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
+        if search_type == "base":
+            response = await search_engine.base_vector_search(
+                query_vec=query_vector,
+                anns_field=anns_field,
                 search_params=search_params,
                 limit=limit,
             )
-
-            async def hybrid():
-                if not query:
-                    return {"error": "missing query"}
-                field = body.get("field", "vector_text")
-                query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
-                return await searcher.hybrid_search(
-                    query_vec=query_vec,
-                    anns_field=field,
-                    filters=body.get("filter_map"),
-                    limit=limit,
-                )
-
-            async def strategy():
-                if not query:
-                    return {"error": "missing query"}
-                query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
-                return await searcher.search_by_strategy(
-                    query_vec=query_vec,
-                    weight_map=body.get("weight_map", WEIGHT_MAP),
-                    expr=expr,
-                    limit=limit,
+        else:
+            response = await search_engine.hybrid_search(
+                filters=filters,
+                query_vec=query_vector,
+                anns_field=anns_field,
+                search_params=search_params,
+                es_size=es_size,
+                sort_by=sort_by,
+                milvus_size=milvus_size,
             )
-
-            # dispatch table
-            handlers = {
-                "by_vector": by_vector,
-                "hybrid": hybrid,
-                "strategy": strategy,
-            }
-
-            if search_type not in handlers:
-                return jsonify({"error": "invalid search_type"}), 400
-
-            result = await handlers[search_type]()
-            return jsonify(result)
-
-        except Exception as e:
-            return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
-
-    return server_bp
+        return jsonify(response), 200
+    except Exception as e:
+        # NOTE(review): the traceback is returned to the caller, which exposes
+        # internals; kept to preserve the response shape -- consider logging
+        # it server-side instead.
+        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
|