Bläddra i källkod

Merge branch 'feature/luojunhui/2025-09-15-add-vector-search' of Server/llm_vector_server into master

luojunhui 3 veckor sedan
förälder
incheckning
57b8272d2a

+ 3 - 1
applications/async_task/chunk_task.py

@@ -27,7 +27,9 @@ class ChunkEmbeddingTask(TopicAwareChunker):
         self.contents_processor = Contents(self.mysql_pool)
         self.content_chunk_processor = ContentChunks(self.mysql_pool)
 
-    async def process_content(self, doc_id: str, text: str, text_type: int) -> List[Chunk]:
+    async def process_content(
+        self, doc_id: str, text: str, text_type: int
+    ) -> List[Chunk]:
         flag = await self.contents_processor.insert_content(doc_id, text, text_type)
         if not flag:
             return []

+ 3 - 1
applications/config/__init__.py

@@ -8,6 +8,7 @@ from .deepseek_config import DEEPSEEK_MODEL, DEEPSEEK_API_KEY
 from .base_chunk import Chunk, ChunkerConfig
 from .milvus_config import MILVUS_CONFIG
 from .mysql_config import RAG_MYSQL_CONFIG
+from .weight_config import WEIGHT_MAP
 
 __all__ = [
     "DEFAULT_MODEL",
@@ -19,5 +20,6 @@ __all__ = [
     "Chunk",
     "ChunkerConfig",
     "MILVUS_CONFIG",
-    "RAG_MYSQL_CONFIG"
+    "RAG_MYSQL_CONFIG",
+    "WEIGHT_MAP",
 ]

+ 2 - 1
applications/config/base_chunk.py

@@ -1,6 +1,7 @@
 from typing import List, Dict, Any
 from dataclasses import dataclass, field, asdict
 
+
 @dataclass
 class Chunk:
     chunk_id: int
@@ -28,4 +29,4 @@ class ChunkerConfig:
     enable_adaptive_boundary: bool = True
     enable_kg: bool = True
     topic_purity_floor: float = 0.8
-    kg_topk: int = 3
+    kg_topk: int = 3

+ 3 - 4
applications/config/milvus_config.py

@@ -1,8 +1,7 @@
-
 MILVUS_CONFIG = {
     # "host": "c-981be0ee7225467b-internal.milvus.aliyuncs.com", # 内网
-    "host": "c-981be0ee7225467b.milvus.aliyuncs.com", # 公网
+    "host": "c-981be0ee7225467b.milvus.aliyuncs.com",  # 公网
     "user": "root",
     "password": "Piaoquan@2025",
-    "port": "19530"
-}
+    "port": "19530",
+}

+ 1 - 0
applications/config/model_config.py

@@ -7,5 +7,6 @@ LOCAL_MODEL_CONFIG = {
 DEFAULT_MODEL = "Qwen3-Embedding-4B"
 
 VLLM_SERVER_URL = "http://vllm-qwen:8000/v1/embeddings"
+# VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"
 
 DEV_VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"

+ 1 - 1
applications/config/mysql_config.py

@@ -7,4 +7,4 @@ RAG_MYSQL_CONFIG = {
     "charset": "utf8mb4",
     "minsize": 5,
     "maxsize": 20,
-}
+}

+ 7 - 0
applications/config/weight_config.py

@@ -0,0 +1,7 @@
+# weight config
+
+WEIGHT_MAP = {
+    "vector_text": 0.6,
+    "question_text": 0.3,
+    "summary_text": 0.6,
+}

+ 1 - 1
applications/utils/chunks/llm_classifier.py

@@ -56,4 +56,4 @@ class LLMClassifier:
             keywords=response.get("keywords", []),
             questions=response.get("questions", []),
             entities=response.get("entities", []),
-        )
+        )

+ 5 - 1
applications/utils/chunks/topic_aware_chunking.py

@@ -136,7 +136,11 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
             tokens = num_tokens(text)
             chunk_id += 1
             chunk = Chunk(
-                doc_id=self.doc_id, chunk_id=chunk_id, text=text, tokens=tokens, text_type=text_type
+                doc_id=self.doc_id,
+                chunk_id=chunk_id,
+                text=text,
+                tokens=tokens,
+                text_type=text_type,
             )
             chunks.append(chunk)
             start = end + 1

+ 7 - 1
applications/utils/milvus/__init__.py

@@ -1,5 +1,11 @@
 from .collection import milvus_collection
 from .functions import async_insert_chunk, async_search_chunk
+from .search import MilvusSearch
 
 
-__all__ = ["milvus_collection", "async_insert_chunk", "async_search_chunk"]
+__all__ = [
+    "milvus_collection",
+    "async_insert_chunk",
+    "async_search_chunk",
+    "MilvusSearch",
+]

+ 1 - 3
applications/utils/milvus/collection.py

@@ -14,9 +14,7 @@ milvus_collection = Collection(name="chunk_multi_embeddings", schema=schema)
 vector_index_params = {
     "index_type": "IVF_FLAT",
     "metric_type": "COSINE",
-    "params": {
-        "M": 16, "efConstruction": 200
-    }
+    "params": {"M": 16, "efConstruction": 200},
 }
 
 milvus_collection.create_index("vector_text", vector_index_params)

+ 1 - 1
applications/utils/milvus/field.py

@@ -54,7 +54,7 @@ fields = [
         max_capacity=5,
         description="隐含问题",
     ),
-FieldSchema(
+    FieldSchema(
         name="entities",
         dtype=DataType.ARRAY,
         element_type=DataType.VARCHAR,

+ 1 - 2
applications/utils/milvus/functions.py

@@ -10,8 +10,7 @@ async def async_insert_chunk(collection: pymilvus.Collection, data: Dict):
     :param data: insert data
     :return:
     """
-    res = await asyncio.to_thread(collection.insert, [data])
-    print(res)
+    await asyncio.to_thread(collection.insert, [data])
 
 
 async def async_search_chunk(

+ 165 - 0
applications/utils/milvus/search.py

@@ -0,0 +1,165 @@
+import asyncio
+from typing import List, Optional, Dict, Any, Union
+
+
+class MilvusBase:
+
+    output_fields = [
+        "doc_id",
+        "chunk_id",
+        # "summary",
+        # "topic",
+        # "domain",
+        # "task_type",
+        # "keywords",
+        # "concepts",
+        # "questions",
+        # "entities",
+        # "tokens",
+        # "topic_purity",
+    ]
+
+    def __init__(self, milvus_pool):
+        self.milvus_pool = milvus_pool
+
+    @staticmethod
+    def hits_to_json(hits):
+        if not hits:
+            return []
+
+        special_keys = {"entities", "concepts", "questions", "keywords"}
+        return [
+            {
+                "pk": hit.id,
+                "score": hit.distance,
+                **{
+                    key: list(value) if key in special_keys else value
+                    for key, value in (hit.get("entity", {}) or {}).items()
+                },
+            }
+            for hit in hits[0]
+        ]
+
+
+class MilvusSearch(MilvusBase):
+
+    # 通过向量匹配
+    async def vector_search(
+        self,
+        query_vec: List[float],
+        anns_field: str = "vector_text",
+        limit: int = 5,
+        expr: Optional[str] = None,
+        search_params: Optional[Dict[str, Any]] = None,
+    ):
+        """向量搜索,可选过滤"""
+        if search_params is None:
+            search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
+
+        response = await asyncio.to_thread(
+            self.milvus_pool.search,
+            data=[query_vec],
+            anns_field=anns_field,
+            param=search_params,
+            limit=limit,
+            expr=expr,
+            output_fields=self.output_fields,
+        )
+        return {"results": self.hits_to_json(response)}
+
+    # 混合搜索(向量 + metadata)
+    async def hybrid_search(
+        self,
+        query_vec: List[float],
+        anns_field: str = "vector_text",
+        limit: int = 5,
+        filters: Optional[Dict[str, Union[str, int, float]]] = None,
+    ):
+        expr = None
+        if filters:
+            parts = []
+            for k, v in filters.items():
+                if isinstance(v, str):
+                    parts.append(f'{k} == "{v}"')
+                else:
+                    parts.append(f"{k} == {v}")
+            expr = " and ".join(parts)
+
+        response = await self.vector_search(
+            query_vec=query_vec, anns_field=anns_field, limit=limit, expr=expr
+        )
+        return self.hits_to_json(response)
+
+    async def search_by_strategy(
+        self,
+        query_vec: List[float],
+        weight_map: Dict,
+        limit: int = 5,
+        expr: Optional[str] = None,
+        search_params: Optional[Dict[str, Any]] = None,
+    ):
+        async def _sub_search(vec, field):
+            return await asyncio.to_thread(
+                self.milvus_pool.search,
+                data=[vec],
+                anns_field=field,
+                param={"metric_type": "COSINE", "params": {"ef": 64}},
+                limit=limit,
+                expr=expr,
+                output_fields=self.output_fields,
+            )
+
+        tasks = {field: _sub_search(query_vec, field) for field in weight_map.keys()}
+        results = await asyncio.gather(*tasks.values())
+
+        scores = {}
+        for (field, weight), res in zip(weight_map.items(), results):
+            for hit in res[0]:
+                key = (hit.id, hit.entity.get("doc_id"), hit.entity.get("chunk_id"))
+                sim_score = 1 - hit.distance
+                scores[key] = scores.get(key, 0) + weight * sim_score
+
+        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:limit]
+
+        return [
+            {"pk": k[0], "doc_id": k[1], "chunk_id": k[2], "score": v}
+            for k, v in ranked
+        ]
+
+
+class MilvusQuery(MilvusBase):
+    # 通过doc_id + chunk_id 获取数据
+    async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
+        expr = f'doc_id == "{doc_id}" and chunk_id == {chunk_id}'
+        response = await asyncio.to_thread(
+            self.milvus_pool.query,
+            expr=expr,
+            output_fields=self.output_fields,
+        )
+        return self.hits_to_json(response)
+
+    # 只按 metadata 条件查询
+    async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
+        exprs = []
+        for k, v in filters.items():
+            if isinstance(v, str):
+                exprs.append(f'{k} == "{v}"')
+            else:
+                exprs.append(f"{k} == {v}")
+        expr = " and ".join(exprs)
+        response = await asyncio.to_thread(
+            self.milvus_pool.query,
+            expr=expr,
+            output_fields=self.output_fields,
+        )
+        print(response)
+        return self.hits_to_json(response)
+
+    # 通过主键获取milvus数据
+    async def get_by_id(self, pk: int):
+        response = await asyncio.to_thread(
+            self.milvus_pool.query,
+            expr=f"id == {pk}",
+            output_fields=self.output_fields,
+        )
+        return self.hits_to_json(response)

+ 2 - 2
applications/utils/mysql/mapper.py

@@ -52,7 +52,7 @@ class ContentChunks(BaseMySQLClient):
                 chunk.text,
                 chunk.tokens,
                 chunk.topic_purity,
-                chunk.text_type
+                chunk.text_type,
             ),
         )
 
@@ -97,6 +97,6 @@ class ContentChunks(BaseMySQLClient):
                 json.dumps(chunk.entities),
                 chunk.doc_id,
                 chunk.chunk_id,
-                ori_status
+                ori_status,
             ),
         )

+ 79 - 3
routes/buleprint.py

@@ -1,10 +1,17 @@
+import traceback
 import uuid
 
 from quart import Blueprint, jsonify, request
 
-from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, ChunkerConfig
+from applications.config import (
+    DEFAULT_MODEL,
+    LOCAL_MODEL_CONFIG,
+    ChunkerConfig,
+    WEIGHT_MAP,
+)
 from applications.api import get_basic_embedding
 from applications.async_task import ChunkEmbeddingTask
+from applications.utils.milvus import MilvusSearch
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 
@@ -30,12 +37,81 @@ def server_routes(mysql_db, vector_db):
         if not text:
             return jsonify({"error": "error  text"})
         doc_id = f"doc-{uuid.uuid4()}"
-        chunk_task = ChunkEmbeddingTask(mysql_db, vector_db, cfg=ChunkerConfig(), doc_id=doc_id)
+        chunk_task = ChunkEmbeddingTask(
+            mysql_db, vector_db, cfg=ChunkerConfig(), doc_id=doc_id
+        )
         doc_id = await chunk_task.deal(body)
         return jsonify({"doc_id": doc_id})
 
     @server_bp.route("/search", methods=["POST"])
     async def search():
-        pass
+        body = await request.get_json()
+        search_type = body.get("search_type")
+        if not search_type:
+            return jsonify({"error": "missing search_type"}), 400
+
+        searcher = MilvusSearch(vector_db)
+
+        try:
+            # 统一参数
+            expr = body.get("expr")
+            search_params = body.get("search_params") or {
+                "metric_type": "COSINE",
+                "params": {"ef": 64},
+            }
+            limit = body.get("limit", 50)
+            query = body.get("query")
+
+            async def by_vector():
+                if not query:
+                    return {"error": "missing query"}
+                field = body.get("field", "vector_text")
+                query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
+                return await searcher.vector_search(
+                    query_vec=query_vec,
+                    anns_field=field,
+                    expr=expr,
+                    search_params=search_params,
+                    limit=limit,
+                )
+
+            async def hybrid():
+                if not query:
+                    return {"error": "missing query"}
+                field = body.get("field", "vector_text")
+                query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
+                return await searcher.hybrid_search(
+                    query_vec=query_vec,
+                    anns_field=field,
+                    filters=body.get("filter_map"),
+                    limit=limit,
+                )
+
+            async def strategy():
+                if not query:
+                    return {"error": "missing query"}
+                query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
+                return await searcher.search_by_strategy(
+                    query_vec=query_vec,
+                    weight_map=body.get("weight_map", WEIGHT_MAP),
+                    expr=expr,
+                    limit=limit,
+                )
+
+            # dispatch table
+            handlers = {
+                "by_vector": by_vector,
+                "hybrid": hybrid,
+                "strategy": strategy,
+            }
+
+            if search_type not in handlers:
+                return jsonify({"error": "invalid search_type"}), 400
+
+            result = await handlers[search_type]()
+            return jsonify(result)
+
+        except Exception as e:
+            return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
 
     return server_bp

+ 3 - 1
vector_app.py

@@ -14,6 +14,7 @@ MODEL_PATH = LOCAL_MODEL_CONFIG[DEFAULT_MODEL]
 app_route = server_routes(mysql_manager, milvus_collection)
 app.register_blueprint(app_route)
 
+
 @app.before_serving
 async def startup():
     print("Starting application...")
@@ -24,7 +25,8 @@ async def startup():
     jieba.initialize()
     print("Jieba dictionary loaded successfully")
 
+
 @app.after_serving
 async def shutdown():
     print("Shutting down application...")
-    await mysql_manager.close_pools()
+    await mysql_manager.close_pools()