Quellcode durchsuchen

开发 graph_expansion

luojunhui vor 1 Woche
Ursprung
Commit
c3387b2575
2 geänderte Dateien mit 129 neuen und 28 gelöschten Zeilen
  1. 129 0
      applications/utils/neo4j/graph_expansion.py
  2. 0 28
      applications/utils/neo4j/query.py

+ 129 - 0
applications/utils/neo4j/graph_expansion.py

@@ -0,0 +1,129 @@
+from typing import List, Dict, Any, Optional
+
+
+class AsyncGraphExpansion:
+    def __init__(self, driver, database: str = "neo4j"):
+        self.driver = driver
+        self.database = database
+
+    # ========== 共现 co-occurrence ==========
+    async def co_occurrence(
+        self,
+        seed_name: str,
+        seed_label: str,
+        other_names: Optional[List[str]] = None,
+        limit: int = 500,
+    ) -> List[str]:
+        """
+        找与种子要素共现的 chunk
+        :param seed_name: 种子名称
+        :param seed_label: 种子标签 ('Entity', 'Concept', 'Topic')
+        :param other_names: 指定要扩展的其它要素名称列表(若为 None,先查共现TopN再扩展)
+        :param limit: 返回数量上限
+        """
+        async with self.driver.session(database=self.database) as session:
+            if not other_names:
+                # 先统计高频共现要素
+                query_top = f"""
+                MATCH (seed:`{seed_label}` {{name:$seed_name}})
+                      <-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]-(gc:GraphChunk)
+                MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(other)
+                WHERE other <> seed
+                RETURN other.name AS name, count(*) AS co_freq
+                ORDER BY co_freq DESC
+                LIMIT 20
+                """
+                records = await session.run(query_top, {"seed_name": seed_name})
+                other_names = [r["name"] async for r in records]
+
+            # 根据共现要素回捞 chunk
+            query_expand = f"""
+            MATCH (seed:`{seed_label}` {{name:$seed_name}})
+                  <-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]-(gc:GraphChunk)
+            MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(o)
+            WHERE o.name IN $other_names
+            RETURN DISTINCT gc.milvus_id AS milvus_id
+            LIMIT $limit
+            """
+            records = await session.run(
+                query_expand,
+                {"seed_name": seed_name, "other_names": other_names, "limit": limit},
+            )
+            return [r["milvus_id"] async for r in records]
+
+    # ========== 路径 Path ==========
+    async def shortest_path_chunks(
+        self,
+        a_name: str,
+        a_label: str,
+        b_name: str,
+        b_label: str,
+        max_len: int = 4,
+        limit: int = 200,
+    ) -> List[str]:
+        """
+        找到两个要素之间的最短路径,并返回路径上的 chunk
+        """
+        query = f"""
+        MATCH (a:`{a_label}` {{name:$a_name}}), (b:`{b_label}` {{name:$b_name}})
+        CALL {{
+          WITH a,b
+          MATCH p = shortestPath(
+            (a)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC|:BELONGS_TO*..{max_len}]-(b)
+          )
+          RETURN p LIMIT 1
+        }}
+        WITH p
+        UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
+        RETURN DISTINCT gc.milvus_id AS milvus_id
+        LIMIT $limit
+        """
+        async with self.driver.session(database=self.database) as session:
+            records = await session.run(query, {"a_name": a_name, "b_name": b_name})
+            return [r["milvus_id"] async for r in records]
+
+    # ========== 扩展 Expansion ==========
+    async def expand_candidates(
+        self, seed_ids: List[str], k_per_relation: int = 200, limit: int = 1000
+    ) -> List[str]:
+        """
+        基于候选 milvus_id 做 1-hop 扩展(实体/概念/主题),并按权重汇总
+        """
+        query = """
+        MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
+        // 同实体
+        MATCH (gc)-[:HAS_ENTITY]->(e)<-[:HAS_ENTITY]-(gc2:GraphChunk)
+        WHERE gc2 <> gc
+        WITH DISTINCT gc2, 1.0 AS w
+        LIMIT $k_per_relation
+
+        UNION
+
+        MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
+        MATCH (gc)-[:HAS_CONCEPT]->(c)<-[:HAS_CONCEPT]-(gc3:GraphChunk)
+        WHERE gc3 <> gc
+        WITH DISTINCT gc3, 0.7 AS w
+        LIMIT $k_per_relation
+
+        UNION
+
+        MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
+        MATCH (gc)-[:HAS_TOPIC]->(t)<-[:HAS_TOPIC]-(gc4:GraphChunk)
+        WHERE gc4 <> gc
+        WITH DISTINCT gc4, 0.6 AS w
+        LIMIT $k_per_relation
+
+        RETURN DISTINCT gc2.milvus_id AS milvus_id, sum(w) AS score
+        ORDER BY score DESC
+        LIMIT $limit
+        """
+        async with self.driver.session(database=self.database) as session:
+            records = await session.run(
+                query,
+                {
+                    "seed_ids": seed_ids,
+                    "k_per_relation": k_per_relation,
+                    "limit": limit,
+                },
+            )
+            return [r["milvus_id"] async for r in records]

+ 0 - 28
applications/utils/neo4j/query.py

@@ -1,28 +0,0 @@
-class AsyncNeo4jQuery:
-    def __init__(self, neo4j):
-        self.neo4j = neo4j
-
-    async def close(self):
-        await self.neo4j.close()
-
-    async def get_document_by_id(self, doc_id: str):
-        query = """
-        MATCH (d:Document {doc_id: $doc_id})
-        OPTIONAL MATCH (d)-[:HAS_CHUNK]->(c:Chunk)
-        RETURN d, collect(c) as chunks
-        """
-        async with self.neo4j.session() as session:
-            result = await session.run(query, doc_id=doc_id)
-            return [
-                record.data() for record in await result.consume().records
-            ]  # 注意 result 需要 async 迭代
-
-    async def search_chunks_by_topic(self, topic: str):
-        query = """
-        MATCH (c:Chunk {topic: $topic})
-        OPTIONAL MATCH (c)-[:HAS_ENTITY]->(e:Entity)
-        RETURN c, collect(e.name) as entities
-        """
-        async with self.neo4j.session() as session:
-            result = await session.run(query, topic=topic)
-            return [record.data() async for record in result]