浏览代码

search_method-v1

luojunhui 3 周之前
父节点
当前提交
29b27dd1af
共有 4 个文件被更改,包括 85 次插入和 74 次删除
  1. + 2 - 1
      applications/config/model_config.py
  2. + 2 - 2
      applications/utils/milvus/__init__.py
  3. + 76 - 45
      applications/utils/milvus/search.py
  4. + 5 - 26
      routes/buleprint.py

+ 2 - 1
applications/config/model_config.py

@@ -6,6 +6,7 @@ LOCAL_MODEL_CONFIG = {
 
 
 DEFAULT_MODEL = "Qwen3-Embedding-4B"
 DEFAULT_MODEL = "Qwen3-Embedding-4B"
 
 
-VLLM_SERVER_URL = "http://vllm-qwen:8000/v1/embeddings"
+# VLLM_SERVER_URL = "http://vllm-qwen:8000/v1/embeddings"
+VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"
 
 
 DEV_VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"
 DEV_VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"

+ 2 - 2
applications/utils/milvus/__init__.py

@@ -1,6 +1,6 @@
 from .collection import milvus_collection
 from .collection import milvus_collection
 from .functions import async_insert_chunk, async_search_chunk
 from .functions import async_insert_chunk, async_search_chunk
-from .search import MilvusSearcher
+from .search import MilvusSearch
 
 
 
 
-__all__ = ["milvus_collection", "async_insert_chunk", "async_search_chunk", "MilvusSearcher"]
+__all__ = ["milvus_collection", "async_insert_chunk", "async_search_chunk", "MilvusSearch"]

+ 76 - 45
applications/utils/milvus/search.py

@@ -1,27 +1,47 @@
 import asyncio
 import asyncio
 from typing import List, Optional, Dict, Any, Union
 from typing import List, Optional, Dict, Any, Union
 
 
-
-class MilvusSearcher:
+class MilvusBase:
 
 
     output_fields = [
     output_fields = [
         "doc_id",
         "doc_id",
         "chunk_id",
         "chunk_id",
-        "summary",
-        "topic",
-        "domain",
-        "task_type",
-        "keywords",
-        "concepts",
-        "questions",
-        "entities",
-        "tokens",
-        "topic_purity",
+        # "summary",
+        # "topic",
+        # "domain",
+        # "task_type",
+        # "keywords",
+        # "concepts",
+        # "questions",
+        # "entities",
+        # "tokens",
+        # "topic_purity",
     ]
     ]
 
 
     def __init__(self, milvus_pool):
     def __init__(self, milvus_pool):
         self.milvus_pool = milvus_pool
         self.milvus_pool = milvus_pool
 
 
+    @staticmethod
+    def hits_to_json(hits):
+        if not hits:
+            return []
+
+        special_keys = {"entities", "concepts", "questions", "keywords"}
+        return [
+            {
+                "pk": hit.id,
+                "score": 1 - hit.distance,
+                **{
+                    key: list(value) if key in special_keys else value
+                    for key, value in (hit.get("entity", {}) or {}).items()
+                }
+            }
+            for hit in hits[0]
+        ]
+
+
+class MilvusSearch(MilvusBase):
+
     # 通过向量匹配
     # 通过向量匹配
     async def vector_search(
     async def vector_search(
         self,
         self,
@@ -34,7 +54,8 @@ class MilvusSearcher:
         """向量搜索,可选过滤"""
         """向量搜索,可选过滤"""
         if search_params is None:
         if search_params is None:
             search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
             search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
-        return await asyncio.to_thread(
+
+        response = await asyncio.to_thread(
             self.milvus_pool.search,
             self.milvus_pool.search,
             data=[query_vec],
             data=[query_vec],
             anns_field=anns_field,
             anns_field=anns_field,
@@ -43,30 +64,9 @@ class MilvusSearcher:
             expr=expr,
             expr=expr,
             output_fields=self.output_fields,
             output_fields=self.output_fields,
         )
         )
+        res = self.hits_to_json(response)
+        return res
 
 
-    # 通过doc_id + chunk_id 获取数据
-    async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
-        expr = f'doc_id == "{doc_id}" and chunk_id == {chunk_id}'
-        return await asyncio.to_thread(
-            self.milvus_pool.query,
-            expr=expr,
-            output_fields=self.output_fields,
-        )
-
-    # 只按 metadata 条件查询
-    async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
-        exprs = []
-        for k, v in filters.items():
-            if isinstance(v, str):
-                exprs.append(f'{k} == "{v}"')
-            else:
-                exprs.append(f"{k} == {v}")
-        expr = " and ".join(exprs)
-        return await asyncio.to_thread(
-            self.milvus_pool.query,
-            expr=expr,
-            output_fields=self.output_fields,
-        )
 
 
     # 混合搜索(向量 + metadata)
     # 混合搜索(向量 + metadata)
     async def hybrid_search(
     async def hybrid_search(
@@ -86,17 +86,10 @@ class MilvusSearcher:
                     parts.append(f"{k} == {v}")
                     parts.append(f"{k} == {v}")
             expr = " and ".join(parts)
             expr = " and ".join(parts)
 
 
-        return await self.vector_search(
+        response = await self.vector_search(
             query_vec=query_vec, anns_field=anns_field, limit=limit, expr=expr
             query_vec=query_vec, anns_field=anns_field, limit=limit, expr=expr
         )
         )
-
-    # 通过主键获取milvus数据
-    async def get_by_id(self, pk: int):
-        return await asyncio.to_thread(
-            self.milvus_pool.query,
-            expr=f"id == {pk}",
-            output_fields=self.output_fields,
-        )
+        return self.hits_to_json(response)
 
 
     async def search_by_strategy(
     async def search_by_strategy(
         self,
         self,
@@ -133,3 +126,41 @@ class MilvusSearcher:
             {"pk": k[0], "doc_id": k[1], "chunk_id": k[2], "score": v}
             {"pk": k[0], "doc_id": k[1], "chunk_id": k[2], "score": v}
             for k, v in ranked
             for k, v in ranked
         ]
         ]
+
+
+class MilvusQuery(MilvusBase):
+    # 通过doc_id + chunk_id 获取数据
+    async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
+        expr = f'doc_id == "{doc_id}" and chunk_id == {chunk_id}'
+        response = await asyncio.to_thread(
+            self.milvus_pool.query,
+            expr=expr,
+            output_fields=self.output_fields,
+        )
+        return self.hits_to_json(response)
+
+    # 只按 metadata 条件查询
+    async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
+        exprs = []
+        for k, v in filters.items():
+            if isinstance(v, str):
+                exprs.append(f'{k} == "{v}"')
+            else:
+                exprs.append(f"{k} == {v}")
+        expr = " and ".join(exprs)
+        response = await asyncio.to_thread(
+            self.milvus_pool.query,
+            expr=expr,
+            output_fields=self.output_fields,
+        )
+        print(response)
+        return self.hits_to_json(response)
+
+    # 通过主键获取milvus数据
+    async def get_by_id(self, pk: int):
+        response = await asyncio.to_thread(
+            self.milvus_pool.query,
+            expr=f"id == {pk}",
+            output_fields=self.output_fields,
+        )
+        return self.hits_to_json(response)

+ 5 - 26
routes/buleprint.py

@@ -1,3 +1,4 @@
+import traceback
 import uuid
 import uuid
 
 
 from quart import Blueprint, jsonify, request
 from quart import Blueprint, jsonify, request
@@ -5,7 +6,7 @@ from quart import Blueprint, jsonify, request
 from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, ChunkerConfig, WEIGHT_MAP
 from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, ChunkerConfig, WEIGHT_MAP
 from applications.api import get_basic_embedding
 from applications.api import get_basic_embedding
 from applications.async_task import ChunkEmbeddingTask
 from applications.async_task import ChunkEmbeddingTask
-from applications.utils.milvus import MilvusSearcher
+from applications.utils.milvus import MilvusSearch
 
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 
 
@@ -42,28 +43,15 @@ def server_routes(mysql_db, vector_db):
         if not search_type:
         if not search_type:
             return jsonify({"error": "missing search_type"}), 400
             return jsonify({"error": "missing search_type"}), 400
 
 
-        searcher = MilvusSearcher(vector_db)
+        searcher = MilvusSearch(vector_db)
 
 
         try:
         try:
             # 统一参数
             # 统一参数
             expr = body.get("expr")
             expr = body.get("expr")
             search_params = body.get("search_params") or {"metric_type": "COSINE", "params": {"ef": 64}}
             search_params = body.get("search_params") or {"metric_type": "COSINE", "params": {"ef": 64}}
-            limit = body.get("limit", 5)
+            limit = body.get("limit", 50)
             query = body.get("query")
             query = body.get("query")
 
 
-            # 定义不同搜索策略
-            async def by_pk_id():
-                pk_id = body.get("id")
-                if not pk_id:
-                    return {"error": "missing id"}
-                return await searcher.get_by_id(pk_id)
-
-            async def by_doc_id():
-                doc_id, chunk_id = body.get("doc_id"), body.get("chunk_id")
-                if not doc_id or chunk_id is None:
-                    return {"error": "missing doc_id or chunk_id"}
-                return await searcher.get_by_doc_and_chunk(doc_id, chunk_id)
-
             async def by_vector():
             async def by_vector():
                 if not query:
                 if not query:
                     return {"error": "missing query"}
                     return {"error": "missing query"}
@@ -77,12 +65,6 @@ def server_routes(mysql_db, vector_db):
                     limit=limit,
                     limit=limit,
                 )
                 )
 
 
-            async def by_filter():
-                filter_map = body.get("filter_map")
-                if not filter_map:
-                    return {"error": "missing filter_map"}
-                return await searcher.filter_search(filter_map)
-
             async def hybrid():
             async def hybrid():
                 if not query:
                 if not query:
                     return {"error": "missing query"}
                     return {"error": "missing query"}
@@ -108,10 +90,7 @@ def server_routes(mysql_db, vector_db):
 
 
             # dispatch table
             # dispatch table
             handlers = {
             handlers = {
-                "pk_id": by_pk_id,
-                "by_doc_id": by_doc_id,
                 "by_vector": by_vector,
                 "by_vector": by_vector,
-                "by_filter": by_filter,
                 "hybrid": hybrid,
                 "hybrid": hybrid,
                 "strategy": strategy,
                 "strategy": strategy,
             }
             }
@@ -123,6 +102,6 @@ def server_routes(mysql_db, vector_db):
             return jsonify(result)
             return jsonify(result)
 
 
         except Exception as e:
         except Exception as e:
-            return jsonify({"error": str(e)}), 500
+            return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
 
 
     return server_bp
     return server_bp