
Merge branch 'feature/luojunhui/2025-09-29-graph-query' of Server/rag_server into master

luojunhui, 1 week ago
parent commit e65becabd0

+ 83 - 1
applications/search/hybrid_search.py

@@ -1,13 +1,16 @@
 from typing import List, Dict, Optional, Any
+
 from .base_search import BaseSearch
 
+from applications.utils.neo4j import AsyncGraphExpansion
 from applications.utils.elastic_search import ElasticSearchStrategy
 
 
 class HybridSearch(BaseSearch):
-    def __init__(self, milvus_pool, es_pool):
+    def __init__(self, milvus_pool, es_pool, graph_pool):
         super().__init__(milvus_pool, es_pool)
         self.es_strategy = ElasticSearchStrategy(self.es_pool)
+        self.graph_expansion = AsyncGraphExpansion(driver=graph_pool)
 
     async def hybrid_search(
         self,
@@ -39,3 +42,82 @@ class HybridSearch(BaseSearch):
             expr=expr,
             search_params=search_params,
         )
+
+    async def hybrid_search_with_graph(
+        self,
+        filters: Dict[str, Any],  # metadata filter conditions
+        query_vec: List[float],  # query embedding vector
+        anns_field: str = "vector_text",  # vector field to search against
+        search_params: Optional[Dict[str, Any]] = None,  # vector distance / search params
+        query_text: Optional[str] = None,  # optional inverted-index query (e.g. by topic)
+        _source: bool = False,  # whether to return source metadata
+        es_size: int = 10000,  # size of the first-stage ES filter
+        sort_by: Optional[str] = None,  # sort order
+        milvus_size: int = 10,  # number of candidates from the Milvus coarse ranking
+        co_occurrence_fields: Optional[Dict[str, Any]] = None,  # seed fields for co-occurrence expansion
+        shortest_path_fields: Optional[Dict[str, Any]] = None,  # endpoints for the shortest-path chunk lookup
+    ):
+        # step1, use elastic_search to filter chunks
+        es_milvus_ids = await self.es_strategy.base_search(
+            filters=filters,
+            text_query=query_text,
+            _source=_source,
+            size=es_size,
+            sort_by=sort_by,
+        )
+        # step2, use graph to get co_occurrence chunks
+        if not co_occurrence_fields:
+            co_occurrence_ids = []
+        else:
+            # test version: use only Entity seeds for now
+            seed_label = "Entity"
+            name = co_occurrence_fields.get(seed_label)
+            if not name:
+                co_occurrence_ids = []
+            else:
+                co_occurrence_ids = await self.graph_expansion.co_occurrence(
+                    seed_name=name, seed_label=seed_label
+                )
+
+        # step3, fetch chunks on the path between two elements
+        if not shortest_path_fields:
+            shortest_path_ids = []
+        else:
+            shortest_path_ids = await self.graph_expansion.shortest_path_chunks(
+                a_name=shortest_path_fields.get("a_name"),
+                a_label=shortest_path_fields.get("a_label"),
+                b_name=shortest_path_fields.get("b_name"),
+                b_label=shortest_path_fields.get("b_label"),
+            )
+        print("es:", es_milvus_ids)
+        print("co:", co_occurrence_ids)
+        print("shortest:", shortest_path_ids)
+
+        # step4, merge the IDs collected above
+        final_milvus_ids = list(
+            set(shortest_path_ids + co_occurrence_ids + es_milvus_ids)
+        )
+
+        # step5, fetch the candidate set via vector search
+        if not final_milvus_ids:
+            return {"results": []}
+
+        milvus_ids_list = ",".join(final_milvus_ids)
+
+        expr = f"id in [{milvus_ids_list}]"
+        return await self.base_vector_search(
+            query_vec=query_vec,
+            anns_field=anns_field,
+            limit=milvus_size,
+            expr=expr,
+            search_params=search_params,
+        )
+
+    async def expand_with_graph(
+        self, milvus_ids: List[int], limit: int
+    ) -> List[Dict[str, Any]]:
+        """Expand the candidate set via the graph (1-hop neighbors of the seed chunks)."""
+        expanded_chunks = await self.graph_expansion.expand_candidates(
+            seed_ids=milvus_ids, limit=limit
+        )
+        return expanded_chunks
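A minimal caller sketch for the new graph-aware entry point, mirroring the "hybrid2" branch added in routes/buleprint.py below. The resource-manager attributes and the concrete filter/seed values are illustrative only, not part of this change:

from applications.search.hybrid_search import HybridSearch

async def demo_graph_search(resource, query_vector):
    # resource is assumed to expose the three pools used by the route layer
    engine = HybridSearch(
        milvus_pool=resource.milvus_client,
        es_pool=resource.es_client,
        graph_pool=resource.graph_client,
    )
    return await engine.hybrid_search_with_graph(
        filters={"entities": ["Neo4j"]},            # ES-level filter (illustrative)
        query_vec=query_vector,                     # embedding of the query text
        co_occurrence_fields={"Entity": "Neo4j"},   # seed for co-occurrence expansion
        shortest_path_fields={                      # optional shortest-path endpoints
            "a_name": "Neo4j", "a_label": "Entity",
            "b_name": "Milvus", "b_label": "Entity",
        },
        milvus_size=10,
    )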

+ 3 - 1
applications/utils/neo4j/__init__.py

@@ -1,7 +1,9 @@
-from .repository import AsyncNeo4jRepository
+from .graph_expansion import AsyncGraphExpansion
 from .models import Document, ChunkRelations, GraphChunk
+from .repository import AsyncNeo4jRepository
 
 __all__ = [
+    "AsyncGraphExpansion",
     "AsyncNeo4jRepository",
     "Document",
     "ChunkRelations",

+ 291 - 0
applications/utils/neo4j/graph_expansion.py

@@ -0,0 +1,291 @@
+import json
+from typing import List, Optional, Dict, Any, Iterable
+
+
+class AsyncGraphExpansion:
+    _REL_ANY = ":HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC"
+
+    def __init__(self, driver, database: str = "neo4j"):
+        self.driver = driver
+        self.database = database
+
+    # ========== Co-occurrence ==========
+    async def co_occurrence(
+        self,
+        seed_name: str,
+        seed_label: str,
+        other_names: Optional[List[str]] = None,
+        limit: int = 500,
+    ) -> List[str]:
+        """
+        Find chunks that co-occur with the seed element.
+        """
+        REL_ANY = self._REL_ANY  # non-deprecated syntax: colon only before the first relationship type
+
+        async with self.driver.session(database=self.database) as session:
+            if not other_names:
+                # single round trip: variable-scoped subquery + non-deprecated relationship "OR" syntax
+                cypher = f"""
+                MATCH (seed:`{seed_label}` {{name:$seed_name}})
+                CALL (seed) {{
+                    MATCH (seed)<-[{REL_ANY}]-(gc:GraphChunk)
+                    MATCH (gc)-[{REL_ANY}]->(other)
+                    WHERE other <> seed
+                    WITH other.name AS name, count(*) AS co_freq
+                    ORDER BY co_freq DESC
+                    RETURN collect(name)[..20] AS sel_names
+                }}
+                WITH seed, sel_names
+                WHERE size(sel_names) > 0
+                MATCH (seed)<-[{REL_ANY}]-(gc2:GraphChunk)
+                MATCH (gc2)-[{REL_ANY}]->(o2)
+                WHERE o2.name IN sel_names
+                RETURN DISTINCT gc2.milvus_id AS milvus_id
+                LIMIT $limit
+                """
+                records = await session.run(
+                    cypher, {"seed_name": seed_name, "limit": limit}
+                )
+                return [str(r["milvus_id"]) async for r in records]
+
+            # other_names provided: expand directly (same non-deprecated relationship "OR" syntax)
+            cypher = f"""
+            MATCH (seed:`{seed_label}` {{name:$seed_name}})
+                  <-[{REL_ANY}]-(gc:GraphChunk)
+            MATCH (gc)-[{REL_ANY}]->(o)
+            WHERE o.name IN $other_names
+            RETURN DISTINCT gc.milvus_id AS milvus_id
+            LIMIT $limit
+            """
+            records = await session.run(
+                cypher,
+                {"seed_name": seed_name, "other_names": other_names, "limit": limit},
+            )
+            return [str(r["milvus_id"]) async for r in records]
+
+    # ========== Multi-seed co-occurrence (each seed carries its own label; performance-optimized) ==========
+    async def co_occurrence_multi_mixed_labels(
+        self,
+        seeds: List[
+            Dict[str, str]
+        ],  # [{"name": "...", "label": "Entity|Concept|Topic"}, ...]
+        other_names: Optional[List[str]] = None,
+        limit: int = 500,
+        top_k: int = 20,
+        min_support: int = 1,  # minimum number of distinct seeds a candidate must co-occur with
+    ) -> List[str]:
+        if not seeds:
+            return []
+
+        # 1) label whitelist + dedupe into per-label buckets (less filtering and scanning in the Cypher branches)
+        allowed = {"Entity", "Concept", "Topic"}
+        seed_entities = sorted({s["name"] for s in seeds if s.get("label") == "Entity"})
+        seed_concepts = sorted(
+            {s["name"] for s in seeds if s.get("label") == "Concept"}
+        )
+        seed_topics = sorted({s["name"] for s in seeds if s.get("label") == "Topic"})
+        if not (set(x.get("label") for x in seeds) <= allowed):
+            raise ValueError("Seed label must be one of Entity/Concept/Topic")
+
+        async with self.driver.session(database=self.database) as session:
+            # 2) compute the top-N names co-occurring across the seeds (when other_names is not explicitly provided)
+            if not other_names:
+                query_top = """
+                // ========== first narrow down to (gc, sname) pairs that co-occur with the seeds ==========
+                // parameters are visible inside CALL subqueries, so no importing WITH is needed
+                CALL {
+                  UNWIND $seed_entities AS sname
+                  MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
+                  RETURN gc, sname
+                  UNION
+                  UNWIND $seed_concepts AS sname
+                  MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
+                  RETURN gc, sname
+                  UNION
+                  UNWIND $seed_topics AS sname
+                  MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
+                  RETURN gc, sname
+                }
+
+                // ========== fan out from these gc nodes to every other element in one pass ==========
+                MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(other)
+                // exclude the seeds themselves from "other" (bucketed by label, still index-friendly)
+                WHERE NOT (
+                  (other:Entity  AND other.name IN $seed_entities) OR
+                  (other:Concept AND other.name IN $seed_concepts) OR
+                  (other:Topic   AND other.name IN $seed_topics)
+                )
+                WITH other.name AS name, gc, sname
+                // frequency is aggregated across gc; support is the number of distinct seeds hit
+                WITH name,
+                     count(DISTINCT gc)    AS co_freq,
+                     count(DISTINCT sname) AS seed_support
+                WHERE seed_support >= $min_support
+                ORDER BY seed_support DESC, co_freq DESC
+                LIMIT $top_k
+                RETURN name
+                """
+
+                params_top = {
+                    "seed_entities": seed_entities,
+                    "seed_concepts": seed_concepts,
+                    "seed_topics": seed_topics,
+                    "top_k": top_k,
+                    "min_support": min_support,
+                }
+                recs = await session.run(query_top, params_top)
+                other_names = [r["name"] async for r in recs]
+                if not other_names:
+                    return []
+
+            # 3) expansion stage: first pin down the gc nodes connected to the seeds,
+            #    then hit other_names via the name index. UNWIND the names and MATCH
+            #    per label so the index is used, instead of the full-scan-prone `o.name IN $list`.
+            query_expand = """
+            // first collect candidate gc nodes (those containing any seed)
+            CALL {
+              UNWIND $seed_entities AS sname
+              MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
+              RETURN DISTINCT gc
+              UNION
+              UNWIND $seed_concepts AS sname
+              MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
+              RETURN DISTINCT gc
+              UNION
+              UNWIND $seed_topics AS sname
+              MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
+              RETURN DISTINCT gc
+            }
+
+            // resolve the "other" nodes via index (UNWIND the list so equality matches can use it)
+            UNWIND $other_names AS oname
+            CALL {
+              WITH oname
+              MATCH (o:Entity {name:oname})  RETURN o
+              UNION
+              WITH oname
+              MATCH (o:Concept {name:oname}) RETURN o
+              UNION
+              WITH oname
+              MATCH (o:Topic {name:oname})   RETURN o
+            }
+
+            MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(o)
+            RETURN DISTINCT gc.milvus_id AS milvus_id
+            LIMIT $limit
+            """
+
+            params_expand = {
+                "seed_entities": seed_entities,
+                "seed_concepts": seed_concepts,
+                "seed_topics": seed_topics,
+                "other_names": other_names,
+                "limit": limit,
+            }
+            recs2 = await session.run(query_expand, params_expand)
+            return [str(r["milvus_id"]) async for r in recs2]
+
+    # ========== Shortest path ==========
+    async def shortest_path_chunks(
+        self,
+        a_name: str,
+        a_label: str,
+        b_name: str,
+        b_label: str,
+        max_len: int = 4,
+        limit: int = 200,
+    ) -> List[str]:
+        """
+        Find the shortest path between two elements and return the chunks on it.
+        """
+        query = f"""
+        MATCH (a:`{a_label}` {{name:$a_name}}), (b:`{b_label}` {{name:$b_name}})
+        CALL {{
+          WITH a,b
+          MATCH p = shortestPath(
+            (a)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC|BELONGS_TO*..{max_len}]-(b)
+          )
+          RETURN p LIMIT 1
+        }}
+        WITH p
+        UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
+        RETURN DISTINCT gc.milvus_id AS milvus_id
+        LIMIT $limit
+        """
+        async with self.driver.session(database=self.database) as session:
+            records = await session.run(
+                query, {"a_name": a_name, "b_name": b_name, "limit": limit}
+            )
+            return [str(r["milvus_id"]) async for r in records]
+
+    # ========== Expansion ==========
+    async def expand_candidates(
+        self, seed_ids: List[int], k_per_relation: int = 200, limit: int = 1000
+    ) -> List[Dict[str, Any]]:
+        """
+        1-hop expansion (entity/concept/topic) from the candidate milvus_ids, aggregated by weight.
+        """
+        query = """
+        CALL {
+          // same entity
+          MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
+          MATCH (gc0)-[:HAS_ENTITY]->(:Entity)<-[:HAS_ENTITY]-(cand:GraphChunk)
+          WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
+          WITH DISTINCT cand, 1.0 AS w
+          LIMIT $k_per_relation
+          RETURN cand AS gc, w
+
+          UNION
+
+          // same concept
+          MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
+          MATCH (gc0)-[:HAS_CONCEPT]->(:Concept)<-[:HAS_CONCEPT]-(cand:GraphChunk)
+          WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
+          WITH DISTINCT cand, 0.7 AS w
+          LIMIT $k_per_relation
+          RETURN cand AS gc, w
+
+          UNION
+
+          // same topic
+          MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
+          MATCH (gc0)-[:HAS_TOPIC]->(:Topic)<-[:HAS_TOPIC]-(cand:GraphChunk)
+          WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
+          WITH DISTINCT cand, 0.6 AS w
+          LIMIT $k_per_relation
+          RETURN cand AS gc, w
+        }
+        WITH gc, sum(w) AS score
+        RETURN
+          gc.milvus_id AS milvus_id,
+          gc.chunk_id  AS chunk_id,   // adjust if your property name differs
+          gc.doc_id    AS doc_id,     // likewise
+          score
+        ORDER BY score DESC
+        LIMIT $limit
+        """
+        async with self.driver.session(database=self.database) as session:
+            records = await session.run(
+                query,
+                {
+                    "seed_ids": seed_ids,
+                    "k_per_relation": k_per_relation,
+                    "limit": limit,
+                },
+            )
+            return [
+                {
+                    "milvus_id": r["milvus_id"],
+                    "chunk_id": r["chunk_id"],
+                    "doc_id": r["doc_id"],
+                    "score": r["score"],
+                }
+                async for r in records
+            ]
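The queries above rely on exact name lookups for Entity/Concept/Topic nodes and on milvus_id lookups for GraphChunk, so they only avoid label scans if matching property indexes exist. A one-off setup sketch, assuming these labels and property names match the graph schema written by the repository layer (the index names are made up):

async def ensure_graph_indexes(driver, database: str = "neo4j"):
    # Create the property indexes the expansion queries depend on (Neo4j 4.4+ syntax).
    statements = [
        "CREATE INDEX entity_name IF NOT EXISTS FOR (n:Entity) ON (n.name)",
        "CREATE INDEX concept_name IF NOT EXISTS FOR (n:Concept) ON (n.name)",
        "CREATE INDEX topic_name IF NOT EXISTS FOR (n:Topic) ON (n.name)",
        "CREATE INDEX graphchunk_milvus_id IF NOT EXISTS FOR (n:GraphChunk) ON (n.milvus_id)",
    ]
    async with driver.session(database=database) as session:
        for stmt in statements:
            await session.run(stmt)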

+ 0 - 28
applications/utils/neo4j/query.py

@@ -1,28 +0,0 @@
-class AsyncNeo4jQuery:
-    def __init__(self, neo4j):
-        self.neo4j = neo4j
-
-    async def close(self):
-        await self.neo4j.close()
-
-    async def get_document_by_id(self, doc_id: str):
-        query = """
-        MATCH (d:Document {doc_id: $doc_id})
-        OPTIONAL MATCH (d)-[:HAS_CHUNK]->(c:Chunk)
-        RETURN d, collect(c) as chunks
-        """
-        async with self.neo4j.session() as session:
-            result = await session.run(query, doc_id=doc_id)
-            return [
-                record.data() for record in await result.consume().records
-            ]  # note: result must be iterated asynchronously
-
-    async def search_chunks_by_topic(self, topic: str):
-        query = """
-        MATCH (c:Chunk {topic: $topic})
-        OPTIONAL MATCH (c)-[:HAS_ENTITY]->(e:Entity)
-        RETURN c, collect(e.name) as entities
-        """
-        async with self.neo4j.session() as session:
-            result = await session.run(query, topic=topic)
-            return [record.data() async for record in result]

+ 21 - 2
routes/buleprint.py

@@ -111,13 +111,16 @@ async def search():
     sort_by: str = body.get("sort_by")
     milvus_size: int = body.get("milvus", 20)
     limit: int = body.get("limit", 10)
+    path_between_chunks: dict = body.get("path_between_chunks", {})
     if not query_text:
         return jsonify({"error": "error  query_text"})
 
     query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
     resource = get_resource_manager()
     search_engine = HybridSearch(
-        milvus_pool=resource.milvus_client, es_pool=resource.es_client
+        milvus_pool=resource.milvus_client,
+        es_pool=resource.es_client,
+        graph_pool=resource.graph_client,
     )
     try:
         match search_type:
@@ -140,6 +143,20 @@ async def search():
                     milvus_size=milvus_size,
                 )
                 return jsonify(response), 200
+            case "hybrid2":
+                entities = (filters or {}).get("entities") or []
+                co_fields = {"Entity": entities[0]} if entities else None
+                response = await search_engine.hybrid_search_with_graph(
+                    filters=filters,
+                    query_vec=query_vector,
+                    anns_field=anns_field,
+                    search_params=search_params,
+                    es_size=es_size,
+                    sort_by=sort_by,
+                    milvus_size=milvus_size,
+                    co_occurrence_fields=co_fields,
+                    shortest_path_fields=path_between_chunks,
+                )
+                return jsonify(response), 200
             case "strategy":
                 return jsonify({"error": "strategy not implemented"}), 405
             case _:
@@ -301,7 +318,9 @@ async def query_search(
     query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
     resource = get_resource_manager()
     search_engine = HybridSearch(
-        milvus_pool=resource.milvus_client, es_pool=resource.es_client
+        milvus_pool=resource.milvus_client,
+        es_pool=resource.es_client,
+        graph_pool=resource.graph_client,
     )
     try:
         match search_type:
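For reference, an illustrative request body that exercises the new "hybrid2" branch. The key names follow the body.get(...) calls shown in this route; the "search_type" and "filters" keys and all concrete values are assumptions, since their body.get(...) calls fall outside this hunk:

hybrid2_body = {
    "search_type": "hybrid2",              # assumed key; dispatched by the match statement above
    "query_text": "graph expansion for retrieval",
    "filters": {"entities": ["Neo4j"]},    # entities[0] seeds the co-occurrence expansion
    "path_between_chunks": {               # optional shortest-path endpoints
        "a_name": "Neo4j", "a_label": "Entity",
        "b_name": "Milvus", "b_label": "Entity",
    },
    "milvus": 20,                          # milvus_size
    "limit": 10,
}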