|
@@ -0,0 +1,129 @@
|
|
|
+from typing import List, Dict, Any, Optional
|
|
|
+
|
|
|
+
|
|
|
+class AsyncGraphExpansion:
|
|
|
+ def __init__(self, driver, database: str = "neo4j"):
|
|
|
+ self.driver = driver
|
|
|
+ self.database = database
|
|
|
+
|
|
|
+ # ========== 共现 co-occurrence ==========
|
|
|
+ async def co_occurrence(
|
|
|
+ self,
|
|
|
+ seed_name: str,
|
|
|
+ seed_label: str,
|
|
|
+ other_names: Optional[List[str]] = None,
|
|
|
+ limit: int = 500,
|
|
|
+ ) -> List[str]:
|
|
|
+ """
|
|
|
+ 找与种子要素共现的 chunk
|
|
|
+ :param seed_name: 种子名称
|
|
|
+ :param seed_label: 种子标签 ('Entity', 'Concept', 'Topic')
|
|
|
+ :param other_names: 指定要扩展的其它要素名称列表(若为 None,先查共现TopN再扩展)
|
|
|
+ :param limit: 返回数量上限
|
|
|
+ """
|
|
|
+ async with self.driver.session(database=self.database) as session:
|
|
|
+ if not other_names:
|
|
|
+ # 先统计高频共现要素
|
|
|
+ query_top = f"""
|
|
|
+ MATCH (seed:`{seed_label}` {{name:$seed_name}})
|
|
|
+ <-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]-(gc:GraphChunk)
|
|
|
+ MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(other)
|
|
|
+ WHERE other <> seed
|
|
|
+ RETURN other.name AS name, count(*) AS co_freq
|
|
|
+ ORDER BY co_freq DESC
|
|
|
+ LIMIT 20
|
|
|
+ """
|
|
|
+ records = await session.run(query_top, {"seed_name": seed_name})
|
|
|
+ other_names = [r["name"] async for r in records]
|
|
|
+
|
|
|
+ # 根据共现要素回捞 chunk
|
|
|
+ query_expand = f"""
|
|
|
+ MATCH (seed:`{seed_label}` {{name:$seed_name}})
|
|
|
+ <-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]-(gc:GraphChunk)
|
|
|
+ MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(o)
|
|
|
+ WHERE o.name IN $other_names
|
|
|
+ RETURN DISTINCT gc.milvus_id AS milvus_id
|
|
|
+ LIMIT $limit
|
|
|
+ """
|
|
|
+ records = await session.run(
|
|
|
+ query_expand,
|
|
|
+ {"seed_name": seed_name, "other_names": other_names, "limit": limit},
|
|
|
+ )
|
|
|
+ return [r["milvus_id"] async for r in records]
|
|
|
+
|
|
|
+ # ========== 路径 Path ==========
|
|
|
+ async def shortest_path_chunks(
|
|
|
+ self,
|
|
|
+ a_name: str,
|
|
|
+ a_label: str,
|
|
|
+ b_name: str,
|
|
|
+ b_label: str,
|
|
|
+ max_len: int = 4,
|
|
|
+ limit: int = 200,
|
|
|
+ ) -> List[str]:
|
|
|
+ """
|
|
|
+ 找到两个要素之间的最短路径,并返回路径上的 chunk
|
|
|
+ """
|
|
|
+ query = f"""
|
|
|
+ MATCH (a:`{a_label}` {{name:$a_name}}), (b:`{b_label}` {{name:$b_name}})
|
|
|
+ CALL {{
|
|
|
+ WITH a,b
|
|
|
+ MATCH p = shortestPath(
|
|
|
+ (a)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC|:BELONGS_TO*..{max_len}]-(b)
|
|
|
+ )
|
|
|
+ RETURN p LIMIT 1
|
|
|
+ }}
|
|
|
+ WITH p
|
|
|
+ UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
|
|
|
+ RETURN DISTINCT gc.milvus_id AS milvus_id
|
|
|
+ LIMIT $limit
|
|
|
+ """
|
|
|
+ async with self.driver.session(database=self.database) as session:
|
|
|
+ records = await session.run(query, {"a_name": a_name, "b_name": b_name})
|
|
|
+ return [r["milvus_id"] async for r in records]
|
|
|
+
|
|
|
+ # ========== 扩展 Expansion ==========
|
|
|
+ async def expand_candidates(
|
|
|
+ self, seed_ids: List[str], k_per_relation: int = 200, limit: int = 1000
|
|
|
+ ) -> List[str]:
|
|
|
+ """
|
|
|
+ 基于候选 milvus_id 做 1-hop 扩展(实体/概念/主题),并按权重汇总
|
|
|
+ """
|
|
|
+ query = """
|
|
|
+ MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
|
|
|
+ // 同实体
|
|
|
+ MATCH (gc)-[:HAS_ENTITY]->(e)<-[:HAS_ENTITY]-(gc2:GraphChunk)
|
|
|
+ WHERE gc2 <> gc
|
|
|
+ WITH DISTINCT gc2, 1.0 AS w
|
|
|
+ LIMIT $k_per_relation
|
|
|
+
|
|
|
+ UNION
|
|
|
+
|
|
|
+ MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
|
|
|
+ MATCH (gc)-[:HAS_CONCEPT]->(c)<-[:HAS_CONCEPT]-(gc3:GraphChunk)
|
|
|
+ WHERE gc3 <> gc
|
|
|
+ WITH DISTINCT gc3, 0.7 AS w
|
|
|
+ LIMIT $k_per_relation
|
|
|
+
|
|
|
+ UNION
|
|
|
+
|
|
|
+ MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
|
|
|
+ MATCH (gc)-[:HAS_TOPIC]->(t)<-[:HAS_TOPIC]-(gc4:GraphChunk)
|
|
|
+ WHERE gc4 <> gc
|
|
|
+ WITH DISTINCT gc4, 0.6 AS w
|
|
|
+ LIMIT $k_per_relation
|
|
|
+
|
|
|
+ RETURN DISTINCT gc2.milvus_id AS milvus_id, sum(w) AS score
|
|
|
+ ORDER BY score DESC
|
|
|
+ LIMIT $limit
|
|
|
+ """
|
|
|
+ async with self.driver.session(database=self.database) as session:
|
|
|
+ records = await session.run(
|
|
|
+ query,
|
|
|
+ {
|
|
|
+ "seed_ids": seed_ids,
|
|
|
+ "k_per_relation": k_per_relation,
|
|
|
+ "limit": limit,
|
|
|
+ },
|
|
|
+ )
|
|
|
+ return [r["milvus_id"] async for r in records]
|