Procházet zdrojové kódy

Merge branch 'feature/luojunhui/2025-09-18-search-engine-improve' of Server/llm_vector_server into master

luojunhui před 2 týdny
rodič
revize
863c02bc2e

+ 81 - 25
applications/async_task/chunk_task.py

@@ -1,23 +1,24 @@
 import asyncio
-import uuid
 from typing import List
 
 from applications.api import get_basic_embedding
 from applications.utils.async_utils import run_tasks_with_asyncio_task_group
-from applications.utils.mysql import ContentChunks, Contents
 from applications.utils.chunks import TopicAwareChunker, LLMClassifier
 from applications.utils.milvus import async_insert_chunk
+from applications.utils.mysql import ContentChunks, Contents
 from applications.config import Chunk, ChunkerConfig, DEFAULT_MODEL
+from applications.config import ELASTIC_SEARCH_INDEX
 
 
 class ChunkEmbeddingTask(TopicAwareChunker):
-    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig, doc_id):
+    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig, doc_id, es_pool):
         super().__init__(cfg, doc_id)
         self.content_chunk_processor = None
         self.contents_processor = None
         self.mysql_pool = mysql_pool
         self.vector_pool = vector_pool
         self.classifier = LLMClassifier()
+        self.es_client = es_pool
 
     @staticmethod
     async def get_embedding_list(text: str) -> List:
@@ -27,14 +28,16 @@ class ChunkEmbeddingTask(TopicAwareChunker):
         self.contents_processor = Contents(self.mysql_pool)
         self.content_chunk_processor = ContentChunks(self.mysql_pool)
 
-    async def process_content(
-        self, doc_id: str, text: str, text_type: int
+    async def _chunk_each_content(
+        self, doc_id: str, text: str, text_type: int, title: str, dataset_id: int
     ) -> List[Chunk]:
-        flag = await self.contents_processor.insert_content(doc_id, text, text_type)
+        flag = await self.contents_processor.insert_content(
+            doc_id, text, text_type, title, dataset_id
+        )
         if not flag:
             return []
         else:
-            raw_chunks = await self.chunk(text, text_type)
+            raw_chunks = await self.chunk(text, text_type, dataset_id)
             if not raw_chunks:
                 await self.contents_processor.update_content_status(
                     doc_id=doc_id,
@@ -50,7 +53,31 @@ class ChunkEmbeddingTask(TopicAwareChunker):
             )
             return raw_chunks
 
-    async def process_each_chunk(self, chunk: Chunk):
+    async def insert_into_es(self, milvus_id, chunk: Chunk) -> int:
+        docs = [
+            {
+                "_index": ELASTIC_SEARCH_INDEX,
+                "_id": milvus_id,
+                "_source": {
+                    "milvus_id": milvus_id,
+                    "doc_id": chunk.doc_id,
+                    "dataset_id": chunk.dataset_id,
+                    "chunk_id": chunk.chunk_id,
+                    "topic": chunk.topic,
+                    "domain": chunk.domain,
+                    "task_type": chunk.task_type,
+                    "text_type": chunk.text_type,
+                    "keywords": chunk.keywords,
+                    "concepts": chunk.concepts,
+                    "entities": chunk.entities,
+                    "status": chunk.status,
+                },
+            }
+        ]
+        resp = await self.es_client.bulk_insert(docs)
+        return resp["success"]
+
+    async def save_each_chunk(self, chunk: Chunk):
         # insert
         flag = await self.content_chunk_processor.insert_chunk(chunk)
         if not flag:
@@ -92,7 +119,30 @@ class ChunkEmbeddingTask(TopicAwareChunker):
             )
             return
 
-        await self.save_to_milvus(completion)
+        milvus_id = await self.save_to_milvus(completion)
+        if not milvus_id:
+            return
+
+        # 存储到 es 中
+        # acquire_lock
+        acquire_es_lock = await self.content_chunk_processor.update_es_status(
+            doc_id=chunk.doc_id,
+            chunk_id=chunk.chunk_id,
+            ori_status=self.INIT_STATUS,
+            new_status=self.PROCESSING_STATUS,
+        )
+        if not acquire_es_lock:
+            print(f"获取 es Lock Fail: {chunk.doc_id}--{chunk.chunk_id}")
+            return
+
+        insert_rows = await self.insert_into_es(milvus_id, completion)
+        final_status = self.FINISHED_STATUS if insert_rows else self.FAILED_STATUS
+        await self.content_chunk_processor.update_es_status(
+            doc_id=chunk.doc_id,
+            chunk_id=chunk.chunk_id,
+            ori_status=self.PROCESSING_STATUS,
+            new_status=final_status,
+        )
 
     async def save_to_milvus(self, chunk: Chunk):
         """
@@ -108,7 +158,7 @@ class ChunkEmbeddingTask(TopicAwareChunker):
         )
         if not acquire_lock:
             print(f"抢占-{chunk.doc_id}-{chunk.chunk_id}分块-embedding处理锁失败")
-            return
+            return None
         try:
             data = {
                 "doc_id": chunk.doc_id,
@@ -118,24 +168,25 @@ class ChunkEmbeddingTask(TopicAwareChunker):
                 "vector_questions": await self.get_embedding_list(
                     ",".join(chunk.questions)
                 ),
-                "topic": chunk.topic,
-                "domain": chunk.domain,
-                "task_type": chunk.task_type,
-                "summary": chunk.summary,
-                "keywords": chunk.keywords,
-                "entities": chunk.entities,
-                "concepts": chunk.concepts,
-                "questions": chunk.questions,
-                "topic_purity": chunk.topic_purity,
-                "tokens": chunk.tokens,
             }
-            await async_insert_chunk(self.vector_pool, data)
+            resp = await async_insert_chunk(self.vector_pool, data)
+            if not resp:
+                await self.content_chunk_processor.update_embedding_status(
+                    doc_id=chunk.doc_id,
+                    chunk_id=chunk.chunk_id,
+                    ori_status=self.PROCESSING_STATUS,
+                    new_status=self.FAILED_STATUS,
+                )
+                return None
+
             await self.content_chunk_processor.update_embedding_status(
                 doc_id=chunk.doc_id,
                 chunk_id=chunk.chunk_id,
                 ori_status=self.PROCESSING_STATUS,
                 new_status=self.FINISHED_STATUS,
             )
+            milvus_id = resp[0]
+            return milvus_id
         except Exception as e:
             await self.content_chunk_processor.update_embedding_status(
                 doc_id=chunk.doc_id,
@@ -144,28 +195,33 @@ class ChunkEmbeddingTask(TopicAwareChunker):
                 new_status=self.FAILED_STATUS,
             )
             print(f"存入向量数据库失败", e)
+            return None
 
     async def deal(self, data):
         text = data.get("text", "")
-        text = text.strip()
+        title = data.get("title", "")
+        text, title = text.strip(), title.strip()
         text_type = data.get("text_type", 1)
+        dataset_id = data.get("dataset_id", 0)  # 默认知识库 id 为 0
         if not text:
             return None
 
         self.init_processer()
 
         async def _process():
-            chunks = await self.process_content(self.doc_id, text, text_type)
+            chunks = await self._chunk_each_content(
+                self.doc_id, text, text_type, title, dataset_id
+            )
             if not chunks:
                 return
 
             # # dev
             # for chunk in chunks:
-            #     await self.process_each_chunk(chunk)
+            #     await self.save_each_chunk(chunk)
 
             await run_tasks_with_asyncio_task_group(
                 task_list=chunks,
-                handler=self.process_each_chunk,
+                handler=self.save_each_chunk,
                 description="处理单篇文章分块",
                 unit="chunk",
                 max_concurrency=10,

+ 7 - 1
applications/config/__init__.py

@@ -6,10 +6,12 @@ from .model_config import (
 )
 from .deepseek_config import DEEPSEEK_MODEL, DEEPSEEK_API_KEY
 from .base_chunk import Chunk, ChunkerConfig
-from .milvus_config import MILVUS_CONFIG
+from .elastic_search_config import ELASTIC_SEARCH_INDEX, ES_HOSTS, ES_PASSWORD
+from .milvus_config import MILVUS_CONFIG, BASE_MILVUS_SEARCH_PARAMS
 from .mysql_config import RAG_MYSQL_CONFIG
 from .weight_config import WEIGHT_MAP
 
+
 __all__ = [
     "DEFAULT_MODEL",
     "LOCAL_MODEL_CONFIG",
@@ -22,4 +24,8 @@ __all__ = [
     "MILVUS_CONFIG",
     "RAG_MYSQL_CONFIG",
     "WEIGHT_MAP",
+    "ES_HOSTS",
+    "ES_PASSWORD",
+    "ELASTIC_SEARCH_INDEX",
+    "BASE_MILVUS_SEARCH_PARAMS",
 ]

+ 2 - 0
applications/config/base_chunk.py

@@ -8,12 +8,14 @@ class Chunk:
     doc_id: str
     text: str
     tokens: int
+    dataset_id: int
     topic: str = ""
     domain: str = ""
     task_type: str = ""
     topic_purity: float = 1.0
     text_type: int = 1
     summary: str = ""
+    status: int = 1
     keywords: List[str] = field(default_factory=list)
     concepts: List[str] = field(default_factory=list)
     questions: List[str] = field(default_factory=list)

+ 5 - 0
applications/config/elastic_search_config.py

@@ -0,0 +1,5 @@
"""Elasticsearch connection settings.

NOTE(review): the password was previously hard-coded for a publicly reachable
endpoint and is committed to version control. It can now be overridden via the
ES_PASSWORD environment variable (the old value is kept as the default for
backward compatibility). Rotate this credential and remove the fallback.
"""

import os

# Index that mirrors Milvus chunk metadata for filtered / keyword search.
ELASTIC_SEARCH_INDEX = "milvus_metadata"

# SECURITY: fallback kept only for backward compatibility; prefer the env var.
ES_PASSWORD = os.environ.get("ES_PASSWORD", "elastic123@")

ES_HOSTS = ["http://es-cn-ols4fypjx00020u36.public.elasticsearch.aliyuncs.com:9200"]

+ 31 - 0
applications/config/es_certs.crt

@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFaTCCA1GgAwIBAgIUWHH9T8PVfiSyvT6S6NrAQ9iSLeEwDQYJKoZIhvcNAQEL
+BQAwPDE6MDgGA1UEAxMxRWxhc3RpY3NlYXJjaCBzZWN1cml0eSBhdXRvLWNvbmZp
+Z3VyYXRpb24gSFRUUCBDQTAeFw0yNTA3MDcwNzIwNTRaFw0yODA3MDYwNzIwNTRa
+MDwxOjA4BgNVBAMTMUVsYXN0aWNzZWFyY2ggc2VjdXJpdHkgYXV0by1jb25maWd1
+cmF0aW9uIEhUVFAgQ0EwggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQCb
+Y8E68+7S+hGKQX6vhyOxuCe3QyBHYlsxiSqGhi+WFx953u4SEMqrbqiyg2QquB9/
+ynjKo3Tvhn0OPjuJRytteKn9OZkVhUT1D5P6PFo0j8x1LIJZm551XRCnQUZ8jC0C
+REHy/JoKdT4YSCRIuXVTM5iM66vQ1t5Du4sb70mTygtc2DyXwgE4LkVnrHcwr2BZ
+3/O69WvF7Zd7WP93yEfUsLsAAQStaCYMeYyaY5K8UwIVcFyWKJ9lnDGbR9KmuXb9
+ipWqGw6aAYhmSs5gL+6xJ5dBpgMOqoBTvZpNniLA/phkelq9W2nAhBLFpRGRof8K
+5iKwjAN8gnBXeSVklBoL23QD5zfoVjz+5eaXWO4qP+90jbwf+vEg/duncDRONGtk
+TQd0Vr9NeO3Aye8PZsmmhKAaciaPWYyQO30omUq9kPsSUzZPu4k+CYb8qwVQCHpn
+Za19NkvERQ8hCQks08/ly5qDM+5lBxJQFQjhjtzSDQ/ybbarMmgaBxpCexiksRmP
+CQqVLW6IaLxUGEkIJqXRx8nmKUfK43vTBitOBFt5UcKob6+ikZLrqZ6xLY/jklE8
+Z1wt9I8ZdQ3L3X9EORgmQ+4KIu/JQfBdfAYtLaS6MYWhiZSaKaIhgfXiZQTO9YuW
+KrI5g+d2Yu2BYgIioLKo9LFWK1eTG2gNAGUI/+rqswIDAQABo2MwYTAdBgNVHQ4E
+FgQUab2kAtPlJHLirQvbThvIwJ7hbLwwHwYDVR0jBBgwFoAUab2kAtPlJHLirQvb
+ThvIwJ7hbLwwDwYDVR0TAQH/BAUwAwEB/zAOBgNVHQ8BAf8EBAMCAQYwDQYJKoZI
+hvcNAQELBQADggIBAF+wJ598Krfai5Br6Vq0Z1jj0JsU8Kij4t9D+89QPgI85/Mv
+zwj8xRgxx9RinKYdnzFJWrD9BITG2l3D0zcJhXfYUpq5HLP+c3zMwEMGzTLbgi70
+cpYqkTJ+g/Ah5WRYZRHJIMF6BVK6izCOO0J49eYC6AONNxG2HeeUvEL4cNnxpw8T
+NUe7v0FXe2iPLeE713h99ray0lBgI6J9QZqc/oEM47gHy+ByfWCv6Yw9qLlprppP
+taHz2VWnCAACDLzbDnYhemQDji86yrUTEdCT8at1jAwHSixgkm88nEBgxPHDuq8t
+thmiS6dELvXVUbyeWO7A/7zVde0Kndxe003OuYcX9I2IX7aIpC8sW/yY+alRhklq
+t9vF6g1qvsN69xXfW5yI5G31TYMUw/3ng0aVJfRFaXkEV2SWEZD+4sWoYC/GU7kK
+zlfaF22jTeul5qCKkN1k+i8K2lheEE3ZBC358W0RyvsrDwtXOra3VCpZ7qrez8OA
+/HeY6iISZQ7g0s209KjqOPqVGcI8B0p6KMh00AeWisU6E/wy1LNTxkf2IS9b88n6
+a3rj0TCycwhKOPTPB5pwlfbZNI00tGTFjqqi07SLqO9ZypsVkyR32G16JPJzk8Zw
+kngBZt6y9LtCMRVbyDuIDNq+fjtDjgxMI9bQXtve4bOuq8cZzcMjC6khz/Ja
+-----END CERTIFICATE-----

+ 5 - 0
applications/config/milvus_config.py

@@ -5,3 +5,8 @@ MILVUS_CONFIG = {
     "password": "Piaoquan@2025",
     "port": "19530",
 }
+
# Default vector-search parameters; "ef" is the HNSW search-time candidate-pool
# size — assumes the collection's vector indexes are HNSW (NOTE(review): the
# resource manager currently declares IVF_FLAT; confirm the index type).
BASE_MILVUS_SEARCH_PARAMS = {
    "metric_type": "COSINE",
    "params": {"ef": 64},
}

+ 2 - 2
applications/config/model_config.py

@@ -6,7 +6,7 @@ LOCAL_MODEL_CONFIG = {
 
 DEFAULT_MODEL = "Qwen3-Embedding-4B"
 
-VLLM_SERVER_URL = "http://vllm-qwen:8000/v1/embeddings"
-# VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"
+# VLLM_SERVER_URL = "http://vllm-qwen:8000/v1/embeddings"
+VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"
 
 DEV_VLLM_SERVER_URL = "http://192.168.100.31:8000/v1/embeddings"

+ 4 - 0
applications/resource/__init__.py

@@ -0,0 +1,4 @@
"""Public entry points for the process-wide resource manager."""

from .resource_manager import get_resource_manager, init_resource_manager

__all__ = ["get_resource_manager", "init_resource_manager"]

+ 87 - 0
applications/resource/resource_manager.py

@@ -0,0 +1,87 @@
from pymilvus import connections, CollectionSchema, Collection

from applications.utils.mysql import DatabaseManager
from applications.utils.milvus.field import fields
from applications.utils.elastic_search import AsyncElasticSearchClient


class ResourceManager:
    """Owns the lifecycle of the shared clients (Elasticsearch, MySQL, Milvus).

    Create once via ``init_resource_manager``; ``startup()`` connects every
    client and ``shutdown()`` releases them.
    """

    def __init__(self, es_index, es_hosts, es_password, milvus_config):
        self.es_index = es_index
        self.es_hosts = es_hosts
        self.es_password = es_password
        self.milvus_config = milvus_config

        # populated by startup(); None until then
        self.es_client: AsyncElasticSearchClient | None = None
        self.milvus_client: Collection | None = None
        self.mysql_client: DatabaseManager | None = None

    async def load_milvus(self):
        """Connect to Milvus and (re)declare the collection, indexes and load it."""
        connections.connect("default", **self.milvus_config)

        schema = CollectionSchema(
            fields, description="Chunk multi-vector embeddings with metadata"
        )
        self.milvus_client = Collection(name="chunk_multi_embeddings_v2", schema=schema)

        # NOTE(review): "M"/"efConstruction" are HNSW build parameters; the
        # original declared index_type "IVF_FLAT" (whose parameter is "nlist"),
        # while the search side (BASE_MILVUS_SEARCH_PARAMS) uses the HNSW "ef"
        # parameter — HNSW is evidently what was intended. Confirm before
        # rollout: changing the index type on an existing collection requires
        # dropping and rebuilding the index.
        vector_index_params = {
            "index_type": "HNSW",
            "metric_type": "COSINE",
            "params": {"M": 16, "efConstruction": 200},
        }
        for vector_field in ("vector_text", "vector_summary", "vector_questions"):
            self.milvus_client.create_index(vector_field, vector_index_params)
        self.milvus_client.load()

    async def startup(self):
        """Initialise all clients; call once at application startup."""
        # Elasticsearch
        self.es_client = AsyncElasticSearchClient(
            index_name=self.es_index, hosts=self.es_hosts, password=self.es_password
        )
        if await self.es_client.es.ping():
            print("✅ Elasticsearch connected")
        else:
            print("❌ Elasticsearch connection failed")

        # MySQL
        self.mysql_client = DatabaseManager()
        await self.mysql_client.init_pools()
        print("✅ MySQL connected")

        # Milvus
        await self.load_milvus()
        print("✅ Milvus loaded")

    async def shutdown(self):
        """Release all clients; safe to call even if startup partially failed."""
        if self.es_client:
            await self.es_client.close()
            print("Elasticsearch closed")

        connections.disconnect("default")
        print("Milvus closed")

        if self.mysql_client:
            await self.mysql_client.close_pools()
            print("Mysql closed")


# process-wide singleton, created lazily by init_resource_manager()
_resource_manager: ResourceManager | None = None


def init_resource_manager(es_index, es_hosts, es_password, milvus_config):
    """Create the singleton ResourceManager on first call and return it.

    Subsequent calls return the existing instance; their arguments are ignored.
    """
    global _resource_manager
    if _resource_manager is None:
        _resource_manager = ResourceManager(
            es_index, es_hosts, es_password, milvus_config
        )

    return _resource_manager


def get_resource_manager() -> ResourceManager:
    """Return the singleton, failing loudly instead of returning None.

    The original silently returned ``None`` before initialisation, deferring
    the failure to an opaque AttributeError at the call site.
    """
    if _resource_manager is None:
        raise RuntimeError(
            "ResourceManager not initialised — call init_resource_manager() first"
        )
    return _resource_manager

+ 3 - 0
applications/search/__init__.py

@@ -0,0 +1,3 @@
"""Search strategies combining Milvus vector search with Elasticsearch filtering."""

from .hybrid_search import HybridSearch

__all__ = ["HybridSearch"]

+ 7 - 0
applications/search/base_search.py

@@ -0,0 +1,7 @@
from applications.utils.milvus import MilvusSearch


class BaseSearch(MilvusSearch):
    """Base class for search strategies that pair Milvus with Elasticsearch.

    Extends MilvusSearch with an Elasticsearch client handle (``es_pool``)
    so subclasses can pre-filter candidates in ES before vector ranking.
    """

    def __init__(self, milvus_pool, es_pool):
        super().__init__(milvus_pool)
        self.es_pool = es_pool

+ 41 - 0
applications/search/hybrid_search.py

@@ -0,0 +1,41 @@
from typing import List, Dict, Optional, Any
from .base_search import BaseSearch

from applications.utils.elastic_search import ElasticSearchStrategy


class HybridSearch(BaseSearch):
    """Two-stage retrieval: ES metadata pre-filter, then Milvus vector ranking."""

    def __init__(self, milvus_pool, es_pool):
        super().__init__(milvus_pool, es_pool)
        self.es_strategy = ElasticSearchStrategy(self.es_pool)

    async def hybrid_search(
        self,
        filters: Dict[str, Any],  # metadata filter conditions
        query_vec: List[float],  # embedded query vector
        anns_field: str = "vector_text",  # which vector space to query
        search_params: Optional[Dict[str, Any]] = None,  # distance/search params
        query_text: Optional[str] = None,  # optional full-text match on topic
        _source: bool = False,  # must stay False: ids are needed for the Milvus expr
        es_size: int = 10000,  # first-stage ES candidate count
        sort_by: Optional[str] = None,  # ES sort order
        milvus_size: int = 10,  # second-stage (coarse-rank) result count
    ):
        """Return ``{"results": [...]}`` ranked by vector similarity within
        the candidate set selected by Elasticsearch.

        NOTE(review): passing ``_source=True`` makes the ES stage return
        ``_source`` dicts instead of ids, which would break the Milvus id
        expression below — keep it False for this strategy.
        """
        candidate_ids = await self.es_strategy.base_search(
            filters=filters,
            text_query=query_text,
            _source=_source,
            size=es_size,
            sort_by=sort_by,
        )
        if not candidate_ids:
            return {"results": []}

        # ES stores the Milvus INT64 primary key as the document _id (string);
        # str() keeps the join safe regardless of the element type.
        id_expr = ", ".join(str(_id) for _id in candidate_ids)
        return await self.base_vector_search(
            query_vec=query_vec,
            anns_field=anns_field,
            limit=milvus_size,
            expr=f"id in [{id_expr}]",
            search_params=search_params,
        )

+ 1 - 0
applications/utils/chunks/llm_classifier.py

@@ -48,6 +48,7 @@ class LLMClassifier:
             text=text,
             tokens=chunk.tokens,
             topic_purity=chunk.topic_purity,
+            dataset_id=chunk.dataset_id,
             summary=response.get("summary"),
             topic=response.get("topic"),
             domain=response.get("domain"),

+ 10 - 3
applications/utils/chunks/topic_aware_chunking.py

@@ -108,7 +108,11 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
         return np.stack(embs)
 
     def _pack_by_boundaries(
-        self, sentence_list: List[str], boundaries: List[int], text_type: int
+        self,
+        sentence_list: List[str],
+        boundaries: List[int],
+        text_type: int,
+        dataset_id: int,
     ) -> List[Chunk]:
         boundary_set = set(boundaries)
         chunks: List[Chunk] = []
@@ -141,6 +145,7 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
                 text=text,
                 tokens=tokens,
                 text_type=text_type,
+                dataset_id=dataset_id,
             )
             chunks.append(chunk)
             start = end + 1
@@ -167,14 +172,16 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
         finally:
             self.cfg.boundary_threshold = orig
 
-    async def chunk(self, text: str, text_type: int) -> List[Chunk]:
+    async def chunk(self, text: str, text_type: int, dataset_id: int) -> List[Chunk]:
         sentence_list = self.jieba_sent_tokenize(text)
         if not sentence_list:
             return []
 
         sentences_embeddings = await self._encode_batch(sentence_list)
         boundaries = self.detect_boundaries(sentence_list, sentences_embeddings)
-        raw_chunks = self._pack_by_boundaries(sentence_list, boundaries, text_type)
+        raw_chunks = self._pack_by_boundaries(
+            sentence_list, boundaries, text_type, dataset_id
+        )
         return raw_chunks
 
 

+ 7 - 0
applications/utils/elastic_search/__init__.py

@@ -0,0 +1,7 @@
"""Async Elasticsearch client and query-strategy helpers."""

# NOTE(review): these config names are imported but neither used here nor
# listed in __all__ — presumably intended as re-exports for convenience;
# confirm and either add them to __all__ or drop the import.
from applications.config import ELASTIC_SEARCH_INDEX, ES_HOSTS, ES_PASSWORD

from .client import AsyncElasticSearchClient
from .search_strategy import ElasticSearchStrategy


__all__ = ["AsyncElasticSearchClient", "ElasticSearchStrategy"]

+ 77 - 0
applications/utils/elastic_search/client.py

@@ -0,0 +1,77 @@
from elasticsearch import AsyncElasticsearch
from elasticsearch.helpers import async_bulk

from applications.utils.async_utils import run_tasks_with_asyncio_task_group


class AsyncElasticSearchClient:
    """Async wrapper around a single-index Elasticsearch connection."""

    def __init__(self, index_name, hosts, password):
        self.es = AsyncElasticsearch(hosts=hosts, basic_auth=("elastic", password))
        self.index_name = index_name

    async def create_index(self, settings, mappings):
        """(Re)create the index.

        WARNING: destructive — deletes the index (and all documents) if it
        already exists before recreating it.
        """
        if await self.es.ping():
            print("ElasticSearch client is up and running")
        else:
            print("ElasticSearch client is not up and running")

        exists = await self.es.indices.exists(index=self.index_name)
        if exists:
            print("index exists")
            await self.es.indices.delete(index=self.index_name)
            print("already delete index")
        try:
            await self.es.indices.create(
                index=self.index_name, settings=settings, mappings=mappings
            )
            print("Index created successfully")
        except Exception as e:
            print("fail to create index, reason:", e)

    async def search(self, query):
        """Run a raw query body against the configured index."""
        resp = await self.es.search(index=self.index_name, body=query)
        return resp

    async def update(self, obj):
        """Apply a partial update. obj: {"es_id": <_id>, "doc": <update body>}."""
        return await self.es.update(
            index=self.index_name, id=obj["es_id"], body=obj["doc"]
        )

    async def update_by_filed(self, field_name: str, field_value: str, doc: dict):
        """Update every document whose ``field_name`` term-matches ``field_value``.

        Returns the task-group result, or None when nothing matched / on error.
        """
        try:
            # look up the matching document ids first
            query = {"query": {"term": {field_name: field_value}}}
            resp = await self.es.search(index=self.index_name, body=query)
            if not resp["hits"]["hits"]:
                print(f"No document found with {field_name}={field_value}")
                return None

            task_list = [
                {"es_id": hit["_id"], "doc": doc} for hit in resp["hits"]["hits"]
            ]

            # BUGFIX: the handler must be self.update, which unpacks the
            # {"es_id", "doc"} task dicts; the original passed self.es.update,
            # which would receive the raw dict as its first argument and fail.
            return await run_tasks_with_asyncio_task_group(
                task_list=task_list,
                handler=self.update,
                description="update by filed",
                unit="document",
                max_concurrency=10,
            )
        except Exception as e:
            print(f"fail to update by {field_name}={field_value}, reason:", e)
            return None

    # correctly-spelled alias; `update_by_filed` is kept for existing callers
    update_by_field = update_by_filed

    async def bulk_insert(self, docs):
        """Bulk-index docs; returns {"success": n, "failed": n, "errors": [...]}.

        raise_on_error=False so per-document failures are reported through the
        return value (which callers inspect for status handling) instead of
        raising BulkIndexError mid-batch.
        """
        success, errors = await async_bulk(
            self.es, docs, request_timeout=10, raise_on_error=False
        )
        return {"success": success, "failed": len(errors), "errors": errors}

    async def close(self):
        await self.es.close()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.es.close()

+ 30 - 0
applications/utils/elastic_search/create_index.py

@@ -0,0 +1,30 @@
"""
One-off helper holding the settings and mappings used when creating the ES
index — only use when create es index.

NOTE(review): `asyncio`, the config constants and the client import are unused
in this module body — presumably kept for running the index creation manually
from a REPL/script; confirm and prune if not.
"""

import asyncio

from applications.config import ELASTIC_SEARCH_INDEX, ES_HOSTS, ES_PASSWORD
from applications.utils.elastic_search.client import AsyncElasticSearchClient


settings = {"number_of_shards": 3, "number_of_replicas": 1}


mappings = {
    "properties": {
        "milvus_id": {"type": "keyword"},  # Milvus (vector DB) primary-key id
        "doc_id": {"type": "keyword"},  # document ID
        "chunk_id": {"type": "long"},  # chunk ID within the document
        "topic": {"type": "text", "fields": {"keyword": {"type": "keyword"}}},  # topic: full-text + exact match
        "domain": {"type": "keyword"},
        "task_type": {"type": "keyword"},
        "text_type": {"type": "keyword"},
        "dataset_id": {"type": "keyword"},
        "keywords": {"type": "keyword"},
        "concepts": {"type": "keyword"},
        "entities": {"type": "keyword"},
        "status": {"type": "keyword"},
        "created_at": {"type": "date", "format": "yyyy-MM-dd HH:mm:ss||epoch_millis"},
    }
}

+ 39 - 0
applications/utils/elastic_search/search_strategy.py

@@ -0,0 +1,39 @@
from typing import Any, Dict, List, Optional


class ElasticSearchStrategy:
    """Thin query-building layer over the async Elasticsearch client."""

    def __init__(self, es):
        # es: AsyncElasticSearchClient (anything exposing `await search(query=...)`)
        self.es = es

    async def base_search(
        self,
        filters: Dict[str, Any],
        text_query: Optional[str] = None,
        _source: bool = False,
        size: int = 10000,
        sort_by: Optional[str] = None,
    ) -> List:
        """Filter documents and return their ids (or full sources).

        :param filters: field -> value(s). Scalar values are wrapped into a
            one-element list so both ``{"f": "x"}`` and ``{"f": ["x"]}`` work
            with the ES ``terms`` query (generalization; list input unchanged).
        :param text_query: optional full-text ``match`` on the ``topic`` field.
        :param _source: when True return each hit's ``_source`` dict instead
            of its ``_id`` string.
        :param size: first-stage filter size (coarse recall window).
        :param sort_by: reserved — currently NOT applied to the query.
            TODO(review): wire this into a ``sort`` clause or remove it.
        :return: list of ids or source dicts; empty list on error (best-effort).
        """
        must_clauses = [
            {"terms": {field: value if isinstance(value, list) else [value]}}
            for field, value in filters.items()
        ]
        if text_query:
            must_clauses.append({"match": {"topic": text_query}})

        query = {
            "query": {"bool": {"must": must_clauses}},
            "size": size,
            "_source": _source,
        }
        try:
            resp = await self.es.search(query=query)
        except Exception as e:
            # callers treat an empty candidate list as "no results"
            print(f"search failed: {e}")
            return []
        return [
            hit["_source"] if _source else hit["_id"]
            for hit in resp["hits"]["hits"]
        ]

    async def search_strategy(self, query):
        # placeholder for future composite strategies
        pass

+ 0 - 2
applications/utils/milvus/__init__.py

@@ -1,10 +1,8 @@
-from .collection import milvus_collection
 from .functions import async_insert_chunk, async_search_chunk
 from .search import MilvusSearch
 
 
 __all__ = [
-    "milvus_collection",
     "async_insert_chunk",
     "async_search_chunk",
     "MilvusSearch",

+ 26 - 26
applications/utils/milvus/collection.py

@@ -1,26 +1,26 @@
-from pymilvus import connections, CollectionSchema, Collection
-from applications.utils.milvus.field import fields
-from applications.config import MILVUS_CONFIG
-
-
-connections.connect("default", **MILVUS_CONFIG)
-
-schema = CollectionSchema(
-    fields, description="Chunk multi-vector embeddings with metadata"
-)
-milvus_collection = Collection(name="chunk_multi_embeddings", schema=schema)
-
-# create index
-vector_index_params = {
-    "index_type": "IVF_FLAT",
-    "metric_type": "COSINE",
-    "params": {"M": 16, "efConstruction": 200},
-}
-
-milvus_collection.create_index("vector_text", vector_index_params)
-milvus_collection.create_index("vector_summary", vector_index_params)
-milvus_collection.create_index("vector_questions", vector_index_params)
-
-milvus_collection.load()
-
-__all__ = ["milvus_collection"]
+# from pymilvus import connections, CollectionSchema, Collection
+# from applications.utils.milvus.field import fields
+# from applications.config import MILVUS_CONFIG
+#
+#
+# connections.connect("default", **MILVUS_CONFIG)
+#
+# schema = CollectionSchema(
+#     fields, description="Chunk multi-vector embeddings with metadata"
+# )
+# milvus_collection = Collection(name="chunk_multi_embeddings_v2", schema=schema)
+#
+# # create index
+# vector_index_params = {
+#     "index_type": "IVF_FLAT",
+#     "metric_type": "COSINE",
+#     "params": {"M": 16, "efConstruction": 200},
+# }
+#
+# milvus_collection.create_index("vector_text", vector_index_params)
+# milvus_collection.create_index("vector_summary", vector_index_params)
+# milvus_collection.create_index("vector_questions", vector_index_params)
+#
+# milvus_collection.load()
+#
+# __all__ = ["milvus_collection"]

+ 14 - 44
applications/utils/milvus/field.py

@@ -2,6 +2,7 @@ from pymilvus import FieldSchema, DataType
 
 # milvus 向量数据库
 fields = [
+    # 主键 ID
     FieldSchema(
         name="id",
         dtype=DataType.INT64,
@@ -9,61 +10,30 @@ fields = [
         auto_id=True,
         description="自增id",
     ),
+    # 文档 id 字段
     FieldSchema(
         name="doc_id", dtype=DataType.VARCHAR, max_length=64, description="文档id"
     ),
     FieldSchema(name="chunk_id", dtype=DataType.INT64, description="文档分块id"),
     # 三种向量字段
-    FieldSchema(name="vector_text", dtype=DataType.FLOAT_VECTOR, dim=2560),
-    FieldSchema(name="vector_summary", dtype=DataType.FLOAT_VECTOR, dim=2560),
-    FieldSchema(name="vector_questions", dtype=DataType.FLOAT_VECTOR, dim=2560),
-    # metadata
     FieldSchema(
-        name="topic", dtype=DataType.VARCHAR, max_length=255, description="主题"
+        name="vector_text",
+        dtype=DataType.FLOAT_VECTOR,
+        dim=2560,
+        description="chunk文本 embedding",
     ),
     FieldSchema(
-        name="domain", dtype=DataType.VARCHAR, max_length=100, description="领域"
+        name="vector_summary",
+        dtype=DataType.FLOAT_VECTOR,
+        dim=2560,
+        description="总结 embedding",
     ),
     FieldSchema(
-        name="task_type", dtype=DataType.VARCHAR, max_length=100, description="任务类型"
+        name="vector_questions",
+        dtype=DataType.FLOAT_VECTOR,
+        dim=2560,
+        description="衍生问题 embedding",
     ),
-    FieldSchema(
-        name="summary", dtype=DataType.VARCHAR, max_length=512, description="总结"
-    ),
-    FieldSchema(
-        name="keywords",
-        dtype=DataType.ARRAY,
-        element_type=DataType.VARCHAR,
-        max_length=100,
-        max_capacity=5,
-        description="关键词",
-    ),
-    FieldSchema(
-        name="concepts",
-        dtype=DataType.ARRAY,
-        element_type=DataType.VARCHAR,
-        max_length=100,
-        max_capacity=5,
-        description="主要知识点",
-    ),
-    FieldSchema(
-        name="questions",
-        dtype=DataType.ARRAY,
-        element_type=DataType.VARCHAR,
-        max_length=200,
-        max_capacity=5,
-        description="隐含问题",
-    ),
-    FieldSchema(
-        name="entities",
-        dtype=DataType.ARRAY,
-        element_type=DataType.VARCHAR,
-        max_length=200,
-        max_capacity=5,
-        description="命名实体",
-    ),
-    FieldSchema(name="topic_purity", dtype=DataType.FLOAT),
-    FieldSchema(name="tokens", dtype=DataType.INT64),
 ]
 
 

+ 4 - 3
applications/utils/milvus/functions.py

@@ -1,16 +1,17 @@
 import asyncio
-from typing import Dict
+from typing import Dict, List
 
 import pymilvus
 
 
-async def async_insert_chunk(collection: pymilvus.Collection, data: Dict):
+async def async_insert_chunk(collection: pymilvus.Collection, data: Dict) -> List[int]:
     """
     :param collection:
     :param data: insert data
     :return:
     """
-    await asyncio.to_thread(collection.insert, [data])
+    result = await asyncio.to_thread(collection.insert, [data])
+    return result.primary_keys
 
 
 async def async_search_chunk(

+ 3 - 73
applications/utils/milvus/search.py

@@ -5,18 +5,9 @@ from typing import List, Optional, Dict, Any, Union
 class MilvusBase:
 
     output_fields = [
+        "id",
         "doc_id",
         "chunk_id",
-        # "summary",
-        # "topic",
-        # "domain",
-        # "task_type",
-        # "keywords",
-        # "concepts",
-        # "questions",
-        # "entities",
-        # "tokens",
-        # "topic_purity",
     ]
 
     def __init__(self, milvus_pool):
@@ -43,8 +34,8 @@ class MilvusBase:
 
 class MilvusSearch(MilvusBase):
 
-    # 通过向量匹配
-    async def vector_search(
+    # 通过向量粗搜索
+    async def base_vector_search(
         self,
         query_vec: List[float],
         anns_field: str = "vector_text",
@@ -67,29 +58,6 @@ class MilvusSearch(MilvusBase):
         )
         return {"results": self.hits_to_json(response)}
 
-    # 混合搜索(向量 + metadata)
-    async def hybrid_search(
-        self,
-        query_vec: List[float],
-        anns_field: str = "vector_text",
-        limit: int = 5,
-        filters: Optional[Dict[str, Union[str, int, float]]] = None,
-    ):
-        expr = None
-        if filters:
-            parts = []
-            for k, v in filters.items():
-                if isinstance(v, str):
-                    parts.append(f'{k} == "{v}"')
-                else:
-                    parts.append(f"{k} == {v}")
-            expr = " and ".join(parts)
-
-        response = await self.vector_search(
-            query_vec=query_vec, anns_field=anns_field, limit=limit, expr=expr
-        )
-        return self.hits_to_json(response)
-
     async def search_by_strategy(
         self,
         query_vec: List[float],
@@ -125,41 +93,3 @@ class MilvusSearch(MilvusBase):
             {"pk": k[0], "doc_id": k[1], "chunk_id": k[2], "score": v}
             for k, v in ranked
         ]
-
-
-class MilvusQuery(MilvusBase):
-    # 通过doc_id + chunk_id 获取数据
-    async def get_by_doc_and_chunk(self, doc_id: str, chunk_id: int):
-        expr = f'doc_id == "{doc_id}" and chunk_id == {chunk_id}'
-        response = await asyncio.to_thread(
-            self.milvus_pool.query,
-            expr=expr,
-            output_fields=self.output_fields,
-        )
-        return self.hits_to_json(response)
-
-    # 只按 metadata 条件查询
-    async def filter_search(self, filters: Dict[str, Union[str, int, float]]):
-        exprs = []
-        for k, v in filters.items():
-            if isinstance(v, str):
-                exprs.append(f'{k} == "{v}"')
-            else:
-                exprs.append(f"{k} == {v}")
-        expr = " and ".join(exprs)
-        response = await asyncio.to_thread(
-            self.milvus_pool.query,
-            expr=expr,
-            output_fields=self.output_fields,
-        )
-        print(response)
-        return self.hits_to_json(response)
-
-    # 通过主键获取milvus数据
-    async def get_by_id(self, pk: int):
-        response = await asyncio.to_thread(
-            self.milvus_pool.query,
-            expr=f"id == {pk}",
-            output_fields=self.output_fields,
-        )
-        return self.hits_to_json(response)

+ 2 - 6
applications/utils/mysql/__init__.py

@@ -2,10 +2,6 @@ from .pool import DatabaseManager
 from .mapper import Contents, ContentChunks
 
 # 全局数据库管理器实例
-mysql_manager = DatabaseManager()
+# mysql_manager = DatabaseManager()
 
-__all__ = [
-    "mysql_manager",
-    "Contents",
-    "ContentChunks",
-]
+__all__ = ["Contents", "ContentChunks", "DatabaseManager"]

+ 18 - 6
applications/utils/mysql/mapper.py

@@ -17,13 +17,15 @@ class BaseMySQLClient:
 
 class Contents(BaseMySQLClient):
 
-    async def insert_content(self, doc_id, text, text_type):
+    async def insert_content(self, doc_id, text, text_type, title, dataset_id):
         query = """
             INSERT IGNORE INTO contents
-                (doc_id, text, text_type)
-            VALUES (%s, %s, %s);
+                (doc_id, text, text_type, title, dataset_id)
+            VALUES (%s, %s, %s, %s, %s);
         """
-        return await self.pool.async_save(query=query, params=(doc_id, text, text_type))
+        return await self.pool.async_save(
+            query=query, params=(doc_id, text, text_type, title, dataset_id)
+        )
 
     async def update_content_status(self, doc_id, ori_status, new_status):
         query = """
@@ -41,8 +43,8 @@ class ContentChunks(BaseMySQLClient):
     async def insert_chunk(self, chunk: Chunk) -> int:
         query = """
             INSERT IGNORE INTO content_chunks
-                (chunk_id, doc_id, text, tokens, topic_purity, text_type) 
-                VALUES (%s, %s, %s, %s, %s, %s);
+                (chunk_id, doc_id, text, tokens, topic_purity, text_type, dataset_id) 
+                VALUES (%s, %s, %s, %s, %s, %s, %s);
         """
         return await self.pool.async_save(
             query=query,
@@ -53,6 +55,7 @@ class ContentChunks(BaseMySQLClient):
                 chunk.tokens,
                 chunk.topic_purity,
                 chunk.text_type,
+                chunk.dataset_id,
             ),
         )
 
@@ -100,3 +103,12 @@ class ContentChunks(BaseMySQLClient):
                 ori_status,
             ),
         )
+
+    async def update_es_status(self, doc_id, chunk_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks SET es_status = %s
+            WHERE doc_id = %s AND chunk_id = %s AND es_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, chunk_id, ori_status)
+        )

+ 2 - 1
requirements.txt

@@ -16,6 +16,7 @@ pip-chill==1.0.3
 pymilvus==2.6.1
 pysocks==1.7.1
 quart-cors==0.8.0
-sentence-transformers==5.1.0
 tiktoken==0.11.0
 uvloop==0.21.0
+elasticsearch==8.17.2
+scikit-learn==1.7.2

+ 2 - 2
routes/__init__.py

@@ -1,3 +1,3 @@
-from .buleprint import server_routes
+from .buleprint import server_bp
 
-__all__ = ["server_routes"]
+__all__ = ["server_bp"]

+ 106 - 107
routes/buleprint.py

@@ -1,5 +1,6 @@
 import traceback
 import uuid
+from typing import Dict, Any
 
 from quart import Blueprint, jsonify, request
 
@@ -7,122 +8,120 @@ from applications.config import (
     DEFAULT_MODEL,
     LOCAL_MODEL_CONFIG,
     ChunkerConfig,
-    WEIGHT_MAP,
+    BASE_MILVUS_SEARCH_PARAMS,
 )
+from applications.resource import get_resource_manager
 from applications.api import get_basic_embedding
 from applications.api import get_img_embedding
 from applications.async_task import ChunkEmbeddingTask
-from applications.utils.milvus import MilvusSearch
+from applications.search import HybridSearch
+
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 
 
-def server_routes(mysql_db, vector_db):
-
-    @server_bp.route("/embed", methods=["POST"])
-    async def embed():
-        body = await request.get_json()
-        text = body.get("text")
-        model_name = body.get("model", DEFAULT_MODEL)
-        if not LOCAL_MODEL_CONFIG.get(model_name):
-            return jsonify({"error": "error  model"})
-
-        embedding = await get_basic_embedding(text, model_name)
-        return jsonify({"embedding": embedding})
-
-    @server_bp.route("/img_embed", methods=["POST"])
-    async def img_embed():
-        body = await request.get_json()
-        url_list = body.get("url_list")
-        if not url_list:
-            return jsonify({"error": "error  url_list"})
-
-        embedding = await get_img_embedding(url_list)
-        return jsonify(embedding)
-
-    @server_bp.route("/chunk", methods=["POST"])
-    async def chunk():
-        body = await request.get_json()
-        text = body.get("text", "")
-        text = text.strip()
-        if not text:
-            return jsonify({"error": "error  text"})
-        doc_id = f"doc-{uuid.uuid4()}"
-        chunk_task = ChunkEmbeddingTask(
-            mysql_db, vector_db, cfg=ChunkerConfig(), doc_id=doc_id
-        )
-        doc_id = await chunk_task.deal(body)
-        return jsonify({"doc_id": doc_id})
-
-    @server_bp.route("/search", methods=["POST"])
-    async def search():
-        body = await request.get_json()
-        search_type = body.get("search_type")
-        if not search_type:
-            return jsonify({"error": "missing search_type"}), 400
-
-        searcher = MilvusSearch(vector_db)
-
-        try:
-            # 统一参数
-            expr = body.get("expr")
-            search_params = body.get("search_params") or {
-                "metric_type": "COSINE",
-                "params": {"ef": 64},
-            }
-            limit = body.get("limit", 50)
-            query = body.get("query")
-
-            async def by_vector():
-                if not query:
-                    return {"error": "missing query"}
-                field = body.get("field", "vector_text")
-                query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
-                return await searcher.vector_search(
-                    query_vec=query_vec,
-                    anns_field=field,
-                    expr=expr,
+@server_bp.route("/embed", methods=["POST"])
+async def embed():
+    body = await request.get_json()
+    text = body.get("text")
+    model_name = body.get("model", DEFAULT_MODEL)
+    if not LOCAL_MODEL_CONFIG.get(model_name):
+        return jsonify({"error": "error  model"})
+
+    embedding = await get_basic_embedding(text, model_name)
+    return jsonify({"embedding": embedding})
+
+
+@server_bp.route("/img_embed", methods=["POST"])
+async def img_embed():
+    body = await request.get_json()
+    url_list = body.get("url_list")
+    if not url_list:
+        return jsonify({"error": "error  url_list"})
+
+    embedding = await get_img_embedding(url_list)
+    return jsonify(embedding)
+
+
+@server_bp.route("/chunk", methods=["POST"])
+async def chunk():
+    body = await request.get_json()
+    text = body.get("text", "")
+    text = text.strip()
+    if not text:
+        return jsonify({"error": "error  text"})
+    resource = get_resource_manager()
+    doc_id = f"doc-{uuid.uuid4()}"
+    chunk_task = ChunkEmbeddingTask(
+        resource.mysql_client,
+        resource.milvus_client,
+        cfg=ChunkerConfig(),
+        doc_id=doc_id,
+        es_pool=resource.es_client,
+    )
+    doc_id = await chunk_task.deal(body)
+    return jsonify({"doc_id": doc_id})
+
+
+@server_bp.route("/search", methods=["POST"])
+async def search():
+    """
+    filters: Dict[str, Any], # 条件过滤
+    query_vec: List[float], # query 的向量
+    anns_field: str = "vector_text", # query指定的向量空间
+    search_params: Optional[Dict[str, Any]] = None, # 向量距离方式
+    query_text: str = None, #是否通过 topic 倒排
+    _source=False, # 是否返回元数据
+    es_size: int = 10000, #es 第一层过滤数量
+    sort_by: str = None, # 排序
+    milvus_size: int = 10 # milvus粗排返回数量
+    :return:
+    """
+    body = await request.get_json()
+
+    # 解析数据
+    search_type: str = body.get("search_type")
+    filters: Dict[str, Any] = body.get("filters", {})
+    anns_field: str = body.get("anns_field", "vector_text")
+    search_params: Dict[str, Any] = body.get("search_params", BASE_MILVUS_SEARCH_PARAMS)
+    query_text: str = body.get("query_text")
+    _source: bool = body.get("_source", False)
+    es_size: int = body.get("es_size", 10000)
+    sort_by: str = body.get("sort_by")
+    milvus_size: int = body.get("milvus", 20)
+    limit: int = body.get("limit", 10)
+    if not query_text:
+        return jsonify({"error": "error  query_text"})
+
+    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
+    resource = get_resource_manager()
+    search_engine = HybridSearch(
+        milvus_pool=resource.milvus_client, es_pool=resource.es_client
+    )
+    try:
+        match search_type:
+            case "base":
+                response = await search_engine.base_vector_search(
+                    query_vec=query_vector,
+                    anns_field=anns_field,
                     search_params=search_params,
                     limit=limit,
                 )
-
-            async def hybrid():
-                if not query:
-                    return {"error": "missing query"}
-                field = body.get("field", "vector_text")
-                query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
-                return await searcher.hybrid_search(
-                    query_vec=query_vec,
-                    anns_field=field,
-                    filters=body.get("filter_map"),
-                    limit=limit,
-                )
-
-            async def strategy():
-                if not query:
-                    return {"error": "missing query"}
-                query_vec = await get_basic_embedding(text=query, model=DEFAULT_MODEL)
-                return await searcher.search_by_strategy(
-                    query_vec=query_vec,
-                    weight_map=body.get("weight_map", WEIGHT_MAP),
-                    expr=expr,
-                    limit=limit,
+                return jsonify(response), 200
+            case "hybrid":
+                response = await search_engine.hybrid_search(
+                    filters=filters,
+                    query_vec=query_vector,
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    es_size=es_size,
+                    sort_by=sort_by,
+                    milvus_size=milvus_size,
                 )
-
-            # dispatch table
-            handlers = {
-                "by_vector": by_vector,
-                "hybrid": hybrid,
-                "strategy": strategy,
-            }
-
-            if search_type not in handlers:
-                return jsonify({"error": "invalid search_type"}), 400
-
-            result = await handlers[search_type]()
-            return jsonify(result)
-
-        except Exception as e:
-            return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
-
-    return server_bp
+                return jsonify(response), 200
+            case "strategy":
+                return jsonify({"error": "strategy not implemented"}), 405
+            case _:
+                return jsonify({"error": "error  search_type"}), 200
+    except Exception as e:
+        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500

+ 20 - 13
vector_app.py

@@ -2,31 +2,38 @@ import jieba
 from quart import Quart
 
 from applications.config import LOCAL_MODEL_CONFIG, DEFAULT_MODEL
-from applications.utils.milvus import milvus_collection
-from applications.utils.mysql import mysql_manager
-from routes import server_routes
+from applications.config import ES_HOSTS, ES_PASSWORD, ELASTIC_SEARCH_INDEX
+from applications.config import MILVUS_CONFIG
+from applications.resource import init_resource_manager
 
 app = Quart(__name__)
 
+# 初始化
 MODEL_PATH = LOCAL_MODEL_CONFIG[DEFAULT_MODEL]
 
-# 注册路由
-app_route = server_routes(mysql_manager, milvus_collection)
-app.register_blueprint(app_route)
+resource_manager = init_resource_manager(
+    es_hosts=ES_HOSTS,
+    es_index=ELASTIC_SEARCH_INDEX,
+    es_password=ES_PASSWORD,
+    milvus_config=MILVUS_CONFIG,
+)
 
 
 @app.before_serving
 async def startup():
-    print("Starting application...")
-    await mysql_manager.init_pools()
-    print("Mysql pools init successfully")
-
-    print("Loading jieba dictionary...")
+    await resource_manager.startup()
+    print("Resource manager is ready.")
     jieba.initialize()
     print("Jieba dictionary loaded successfully")
 
 
 @app.after_serving
 async def shutdown():
-    print("Shutting down application...")
-    await mysql_manager.close_pools()
+    await resource_manager.shutdown()
+    print("Resource manager is Down.")
+
+
+# 注册路由
+from routes import server_bp
+
+app.register_blueprint(server_bp)