瀏覽代碼

search_method

luojunhui 3 周之前
父節點
當前提交
fb494b996c

+ 3 - 1
applications/config/__init__.py

@@ -8,6 +8,7 @@ from .deepseek_config import DEEPSEEK_MODEL, DEEPSEEK_API_KEY
 from .base_chunk import Chunk, ChunkerConfig
 from .milvus_config import MILVUS_CONFIG
 from .mysql_config import RAG_MYSQL_CONFIG
+from .weight_config import WEIGHT_MAP
 
 __all__ = [
     "DEFAULT_MODEL",
@@ -19,5 +20,6 @@ __all__ = [
     "Chunk",
     "ChunkerConfig",
     "MILVUS_CONFIG",
-    "RAG_MYSQL_CONFIG"
+    "RAG_MYSQL_CONFIG",
+    "WEIGHT_MAP"
 ]

+ 7 - 0
applications/config/weight_config.py

@@ -0,0 +1,7 @@
+# Weight config: weight applied to each vector field's search score.
+
+WEIGHT_MAP = dict(
+    vector_text=0.6,
+    question_text=0.3,
+    summary_text=0.6,
+)

+ 2 - 1
applications/utils/milvus/__init__.py

@@ -1,5 +1,6 @@
 from .collection import milvus_collection
 from .functions import async_insert_chunk, async_search_chunk
+from .search import MilvusSearcher
 
 
-__all__ = ["milvus_collection", "async_insert_chunk", "async_search_chunk"]
+__all__ = ["milvus_collection", "async_insert_chunk", "async_search_chunk", "MilvusSearcher"]

+ 1 - 3
applications/utils/milvus/collection.py

@@ -14,9 +14,7 @@ milvus_collection = Collection(name="chunk_multi_embeddings", schema=schema)
 vector_index_params = {
     "index_type": "IVF_FLAT",
     "metric_type": "COSINE",
-    "params": {
-        "M": 16, "efConstruction": 200
-    }
+    "params": {"M": 16, "efConstruction": 200},
 }
 
 milvus_collection.create_index("vector_text", vector_index_params)

+ 1 - 1
applications/utils/milvus/field.py

@@ -54,7 +54,7 @@ fields = [
         max_capacity=5,
         description="隐含问题",
     ),
-FieldSchema(
+    FieldSchema(
         name="entities",
         dtype=DataType.ARRAY,
         element_type=DataType.VARCHAR,

+ 1 - 2
applications/utils/milvus/functions.py

@@ -10,8 +10,7 @@ async def async_insert_chunk(collection: pymilvus.Collection, data: Dict):
     :param data: insert data
     :return:
     """
-    res = await asyncio.to_thread(collection.insert, [data])
-    print(res)
+    await asyncio.to_thread(collection.insert, [data])
 
 
 async def async_search_chunk(

+ 135 - 0
applications/utils/milvus/search.py

@@ -0,0 +1,135 @@
+import asyncio
+from typing import List, Optional, Dict, Any, Union
+
+
+class MilvusSearcher:
+    """Async search and query helpers over a Milvus handle.
+
+    Each method forwards to ``milvus_pool.search`` or ``milvus_pool.query``
+    through ``asyncio.to_thread`` so the call runs in a worker thread instead
+    of on the event loop.
+    """
+
+    # Scalar fields requested with every search and query.
+    output_fields = [
+        "doc_id",
+        "chunk_id",
+        "summary",
+        "topic",
+        "domain",
+        "task_type",
+        "keywords",
+        "concepts",
+        "questions",
+        "entities",
+        "tokens",
+        "topic_purity",
+    ]
+
+    def __init__(self, milvus_pool):
+        """Keep the handle whose ``search`` and ``query`` methods are called."""
+        self.milvus_pool = milvus_pool
+
+    @staticmethod
+    def _quote(value: str) -> str:
+        """Render ``value`` as a double-quoted filter string literal.
+
+        Backslashes and double quotes are escaped so the value cannot end the
+        literal early and change the surrounding expression.
+        """
+        escaped = value.replace("\\", "\\\\").replace('"', '\\"')
+        return f'"{escaped}"'
+
+    @classmethod
+    def _build_filter_expr(cls, filters: Dict[str, Union[str, int, float]]) -> str:
+        """Join one ``field == value`` condition per entry with ``and``.
+
+        String values are quoted and escaped; other values are interpolated
+        as they are.
+        """
+        return " and ".join(
+            f"{field} == {cls._quote(value)}"
+            if isinstance(value, str)
+            else f"{field} == {value}"
+            for field, value in filters.items()
+        )
+
+    async def vector_search(
+        self,
+        query_vec: List[float],
+        anns_field: str = "vector_text",
+        limit: int = 5,
+        expr: Optional[str] = None,
+        search_params: Optional[Dict[str, Any]] = None,
+    ):
+        """Run a vector search on ``anns_field``, optionally filtered.
+
+        :param query_vec: query embedding
+        :param anns_field: name of the vector field to search
+        :param limit: maximum number of hits to return
+        :param expr: optional filter expression applied to the search
+        :param search_params: search parameters; defaults to the COSINE
+            metric with ``ef=64``
+        :return: whatever ``milvus_pool.search`` returns
+        """
+        if search_params is None:
+            # NOTE(review): "ef" is an HNSW search parameter - confirm it
+            # matches the index type actually built on the collection.
+            search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
+        return await asyncio.to_thread(
+            self.milvus_pool.search,
+            data=[query_vec],
+            anns_field=anns_field,
+            param=search_params,
+            limit=limit,
+            expr=expr,
+            output_fields=self.output_fields,
+        )
+
+    async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
+        """Query the rows matching both ``doc_id`` and ``chunk_id``.
+
+        :raises ValueError: if ``chunk_id`` cannot be converted to an int
+        """
+        # Quote doc_id and coerce chunk_id so neither can inject extra
+        # conditions into the expression.
+        expr = f"doc_id == {self._quote(str(doc_id))} and chunk_id == {int(chunk_id)}"
+        return await asyncio.to_thread(
+            self.milvus_pool.query,
+            expr=expr,
+            output_fields=self.output_fields,
+        )
+
+    async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
+        """Query rows by metadata equality conditions only.
+
+        :param filters: mapping of field name to the value it must equal
+        :raises ValueError: if ``filters`` is empty, because that would
+            produce an empty filter expression
+        """
+        if not filters:
+            raise ValueError("filters must not be empty")
+        return await asyncio.to_thread(
+            self.milvus_pool.query,
+            expr=self._build_filter_expr(filters),
+            output_fields=self.output_fields,
+        )
+
+    async def hybrid_search(
+        self,
+        query_vec: List[float],
+        anns_field: str = "vector_text",
+        limit: int = 5,
+        filters: Optional[Dict[str, Union[str, int, float]]] = None,
+    ):
+        """Run a vector search restricted by metadata equality conditions.
+
+        With no ``filters`` this is a plain vector search using the default
+        search parameters.
+        """
+        expr = self._build_filter_expr(filters) if filters else None
+        return await self.vector_search(
+            query_vec=query_vec, anns_field=anns_field, limit=limit, expr=expr
+        )
+
+    async def get_by_id(self, pk: int):
+        """Query the row whose primary key ``id`` equals ``pk``.
+
+        :raises ValueError: if ``pk`` cannot be converted to an int
+        """
+        return await asyncio.to_thread(
+            self.milvus_pool.query,
+            expr=f"id == {int(pk)}",
+            output_fields=self.output_fields,
+        )
+
+    async def search_by_strategy(
+        self,
+        query_vec: List[float],
+        weight_map: Dict,
+        limit: int = 5,
+        expr: Optional[str] = None,
+        search_params: Optional[Dict[str, Any]] = None,
+    ):
+        """Search several vector fields and merge hits by weighted score.
+
+        One search per key of ``weight_map`` is run concurrently. A chunk's
+        final score is the sum over fields of ``weight * hit.distance``; a
+        field that did not return the chunk contributes nothing.
+
+        :param query_vec: query embedding, used for every field
+        :param weight_map: mapping of vector field name to its weight
+        :param limit: hits requested per field and size of the final list
+        :param expr: optional filter expression applied to every search
+        :param search_params: search parameters for every field; defaults to
+            the COSINE metric with ``ef=64``
+        :return: list of ``{"pk", "doc_id", "chunk_id", "score"}`` dicts,
+            highest score first
+        """
+        fields = list(weight_map)
+        results = await asyncio.gather(
+            *(
+                self.vector_search(
+                    query_vec=query_vec,
+                    anns_field=field,
+                    limit=limit,
+                    expr=expr,
+                    search_params=search_params,
+                )
+                for field in fields
+            )
+        )
+
+        scores: Dict[Any, float] = {}
+        for field, res in zip(fields, results):
+            weight = weight_map[field]
+            # A single query vector was sent, so res[0] holds its hits.
+            for hit in res[0]:
+                key = (hit.id, hit.entity.get("doc_id"), hit.entity.get("chunk_id"))
+                # Under the COSINE metric Milvus reports the similarity itself
+                # as hit.distance (larger is closer), so use it directly;
+                # "1 - distance" would rank the least similar chunks first.
+                # NOTE(review): a smaller-is-closer metric such as L2 passed
+                # through search_params would need the opposite ordering.
+                scores[key] = scores.get(key, 0) + weight * hit.distance
+
+        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:limit]
+
+        return [
+            {"pk": key[0], "doc_id": key[1], "chunk_id": key[2], "score": score}
+            for key, score in ranked
+        ]

+ 89 - 2
routes/buleprint.py

@@ -2,9 +2,10 @@ import uuid
 
 from quart import Blueprint, jsonify, request
 
-from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, ChunkerConfig
+from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, ChunkerConfig, WEIGHT_MAP
 from applications.api import get_basic_embedding
 from applications.async_task import ChunkEmbeddingTask
+from applications.utils.milvus import MilvusSearcher
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 
@@ -36,6 +37,92 @@ def server_routes(mysql_db, vector_db):
 
     @server_bp.route("/search", methods=["POST"])
     async def search():
-        pass
+        """Dispatch a search request to the strategy named by ``search_type``.
+
+        Expects a JSON object body. Responds 400 when the body, ``limit``,
+        ``search_type`` or a strategy's required argument is missing or
+        invalid, and 500 (with the exception text) when the search fails.
+        """
+        # silent=True yields None instead of raising on an unparsable body.
+        body = await request.get_json(silent=True)
+        if not isinstance(body, dict):
+            return jsonify({"error": "request body must be a JSON object"}), 400
+
+        search_type = body.get("search_type")
+        if not search_type:
+            return jsonify({"error": "missing search_type"}), 400
+
+        limit = body.get("limit", 5)
+        # bool is a subclass of int, so exclude it explicitly.
+        if isinstance(limit, bool) or not isinstance(limit, int) or limit <= 0:
+            return jsonify({"error": "limit must be a positive integer"}), 400
+
+        searcher = MilvusSearcher(vector_db)
+
+        try:
+            # Parameters shared by several strategies.
+            expr = body.get("expr")
+            search_params = body.get("search_params") or {"metric_type": "COSINE", "params": {"ef": 64}}
+            query = body.get("query")
+
+            # One coroutine per search strategy. A missing argument is
+            # reported by returning an {"error": ...} dict.
+            async def by_pk_id():
+                pk_id = body.get("id")
+                if not pk_id:
+                    return {"error": "missing id"}
+                return await searcher.get_by_id(pk_id)
+
+            async def by_doc_id():
+                doc_id, chunk_id = body.get("doc_id"), body.get("chunk_id")
+                if not doc_id or chunk_id is None:
+                    return {"error": "missing doc_id or chunk_id"}
+                return await searcher.get_by_doc_and_chunk(doc_id, chunk_id)
+
+            async def by_vector():
+                if not query:
+                    return {"error": "missing query"}
+                field = body.get("field", "vector_text")
+                query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
+                return await searcher.vector_search(
+                    query_vec=query_vec,
+                    anns_field=field,
+                    expr=expr,
+                    search_params=search_params,
+                    limit=limit,
+                )
+
+            async def by_filter():
+                filter_map = body.get("filter_map")
+                if not filter_map:
+                    return {"error": "missing filter_map"}
+                return await searcher.filter_search(filter_map)
+
+            async def hybrid():
+                if not query:
+                    return {"error": "missing query"}
+                field = body.get("field", "vector_text")
+                query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
+                return await searcher.hybrid_search(
+                    query_vec=query_vec,
+                    anns_field=field,
+                    filters=body.get("filter_map"),
+                    limit=limit,
+                )
+
+            async def strategy():
+                if not query:
+                    return {"error": "missing query"}
+                query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
+                return await searcher.search_by_strategy(
+                    query_vec=query_vec,
+                    # "or" also covers an explicit null weight_map.
+                    weight_map=body.get("weight_map") or WEIGHT_MAP,
+                    expr=expr,
+                    limit=limit,
+                    search_params=search_params,
+                )
+
+            # Dispatch table: search_type value -> strategy coroutine.
+            handlers = {
+                "pk_id": by_pk_id,
+                "by_doc_id": by_doc_id,
+                "by_vector": by_vector,
+                "by_filter": by_filter,
+                "hybrid": hybrid,
+                "strategy": strategy,
+            }
+
+            if search_type not in handlers:
+                return jsonify({"error": "invalid search_type"}), 400
+
+            result = await handlers[search_type]()
+            # Only the validation branches above return a dict with an
+            # "error" key; report those as client errors, not as 200.
+            if isinstance(result, dict) and "error" in result:
+                return jsonify(result), 400
+            # NOTE(review): the vector strategies return the raw search
+            # result object - confirm jsonify can serialize it.
+            return jsonify(result)
+
+        except Exception as e:
+            return jsonify({"error": str(e)}), 500
 
     return server_bp