from typing import List, Dict, Any, Optional


class AsyncGraphExpansion:
    """Graph-based recall helpers that map knowledge-graph elements to chunk ids.

    Every query walks ``GraphChunk`` nodes and the element nodes they point to
    through ``HAS_ENTITY`` / ``HAS_CONCEPT`` / ``HAS_TOPIC`` relationships, and
    returns the chunks' ``milvus_id`` property values.
    """

    # Relationship types linking a GraphChunk to the elements it mentions.
    # Plain `A|B` alternation is accepted by Neo4j 4.x and 5.x; the `A|:B`
    # spelling is deprecated in Neo4j 5.
    _ELEMENT_RELS = "HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC"

    def __init__(self, driver, database: str = "neo4j"):
        """
        :param driver: Neo4j driver object; only ``driver.session(database=...)``
            is used, as an async context manager (presumably a
            ``neo4j.AsyncDriver`` -- confirm at the call site)
        :param database: database name every session is opened on
        """
        self.driver = driver
        self.database = database

    @staticmethod
    def _quote_label(label: str) -> str:
        """Return *label* as a backtick-quoted Cypher name.

        Labels cannot be sent as query parameters, so they are spliced into the
        query text. Doubling embedded backticks (Cypher's escape inside quoted
        names) keeps a label from terminating the quotes early.
        """
        return "`" + label.replace("`", "``") + "`"

    # ========== Co-occurrence ==========
    async def co_occurrence(
        self,
        seed_name: str,
        seed_label: str,
        other_names: Optional[List[str]] = None,
        limit: int = 500,
        top_k: int = 20,
    ) -> List[str]:
        """Find chunks in which the seed element co-occurs with other elements.

        :param seed_name: ``name`` property of the seed node
        :param seed_label: label of the seed node ('Entity', 'Concept', 'Topic')
        :param other_names: names of the co-occurring elements to expand on;
            when None (or empty), the ``top_k`` elements that share the most
            chunks with the seed are looked up first and used instead
        :param limit: maximum number of chunk ids returned
        :param top_k: how many co-occurring elements to pick when
            ``other_names`` is not given (previously fixed at 20)
        :return: distinct ``milvus_id`` values of chunks linked to the seed and
            to at least one of the selected elements
        """
        seed_lbl = self._quote_label(seed_label)
        rels = self._ELEMENT_RELS
        async with self.driver.session(database=self.database) as session:
            if not other_names:
                # Rank the elements sharing a chunk with the seed by frequency.
                query_top = f"""
                MATCH (seed:{seed_lbl} {{name:$seed_name}})
                      <-[:{rels}]-(gc:GraphChunk)
                MATCH (gc)-[:{rels}]->(other)
                WHERE other <> seed
                RETURN other.name AS name, count(*) AS co_freq
                ORDER BY co_freq DESC
                LIMIT $top_k
                """
                records = await session.run(
                    query_top, {"seed_name": seed_name, "top_k": top_k}
                )
                # A null name can never satisfy `o.name IN ...`, so drop it.
                other_names = [
                    r["name"] async for r in records if r["name"] is not None
                ]
                if not other_names:
                    # Nothing co-occurs with the seed: the expansion would be empty.
                    return []

            # Fetch chunks holding the seed together with one of the other elements.
            query_expand = f"""
            MATCH (seed:{seed_lbl} {{name:$seed_name}})
                  <-[:{rels}]-(gc:GraphChunk)
            MATCH (gc)-[:{rels}]->(o)
            WHERE o.name IN $other_names
            RETURN DISTINCT gc.milvus_id AS milvus_id
            LIMIT $limit
            """
            records = await session.run(
                query_expand,
                {"seed_name": seed_name, "other_names": other_names, "limit": limit},
            )
            return [r["milvus_id"] async for r in records]

    # ========== Path ==========
    async def shortest_path_chunks(
        self,
        a_name: str,
        a_label: str,
        b_name: str,
        b_label: str,
        max_len: int = 4,
        limit: int = 200,
    ) -> List[str]:
        """Return the chunks lying on a shortest path between two elements.

        For every matching (a, b) node pair, one shortest undirected path of at
        most ``max_len`` hops over the chunk/element relationships (plus
        ``BELONGS_TO``) is taken, and the ``GraphChunk`` nodes on it are
        collected.

        :param a_name: ``name`` of the first element
        :param a_label: label of the first element
        :param b_name: ``name`` of the second element
        :param b_label: label of the second element
        :param max_len: maximum path length in hops
        :param limit: maximum number of chunk ids returned
        :return: distinct ``milvus_id`` values of chunks on the path(s); empty
            when the two names resolve to the same node or no path exists
        """
        # Variable-length bounds cannot be parameters; int() ensures only a
        # number is spliced into the query text.
        hops = int(max_len)
        query = f"""
        MATCH (a:{self._quote_label(a_label)} {{name:$a_name}}),
              (b:{self._quote_label(b_label)} {{name:$b_name}})
        WHERE a <> b
        CALL {{
            WITH a, b
            MATCH p = shortestPath(
                (a)-[:{self._ELEMENT_RELS}|BELONGS_TO*..{hops}]-(b)
            )
            RETURN p
            LIMIT 1
        }}
        WITH p
        UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
        RETURN DISTINCT gc.milvus_id AS milvus_id
        LIMIT $limit
        """
        # `WHERE a <> b` above: shortestPath raises when start and end are the
        # same node, so such pairs are skipped instead.
        async with self.driver.session(database=self.database) as session:
            # `limit` must be sent: the query references $limit.
            records = await session.run(
                query, {"a_name": a_name, "b_name": b_name, "limit": limit}
            )
            return [r["milvus_id"] async for r in records]

    # ========== Expansion ==========
    async def expand_candidates(
        self, seed_ids: List[str], k_per_relation: int = 200, limit: int = 1000
    ) -> List[str]:
        """One-hop expansion of candidate chunks through shared graph elements.

        A chunk is a neighbour of a seed chunk when both link to the same
        entity (weight 1.0), concept (0.7) or topic (0.6). Each relation kind
        contributes at most ``k_per_relation`` distinct neighbours. A
        neighbour's score is the sum of the weights of the relation kinds it
        was found through.

        :param seed_ids: ``milvus_id`` values of the seed chunks
        :param k_per_relation: cap on neighbours taken per relation kind
        :param limit: maximum number of ids returned
        :return: neighbour ``milvus_id`` values, highest score first (ties
            broken by id for a stable order)
        """
        if not seed_ids:
            # Nothing to expand from; skip the database round-trip.
            return []
        # Cypher cannot aggregate across a top-level UNION, so the per-relation
        # branches run inside CALL {...} and the scores are summed afterwards.
        query = """
        CALL {
            MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
            MATCH (gc)-[:HAS_ENTITY]->()<-[:HAS_ENTITY]-(cand:GraphChunk)
            WHERE cand <> gc
            RETURN DISTINCT cand, 1.0 AS w
            LIMIT $k_per_relation
          UNION ALL
            MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
            MATCH (gc)-[:HAS_CONCEPT]->()<-[:HAS_CONCEPT]-(cand:GraphChunk)
            WHERE cand <> gc
            RETURN DISTINCT cand, 0.7 AS w
            LIMIT $k_per_relation
          UNION ALL
            MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
            MATCH (gc)-[:HAS_TOPIC]->()<-[:HAS_TOPIC]-(cand:GraphChunk)
            WHERE cand <> gc
            RETURN DISTINCT cand, 0.6 AS w
            LIMIT $k_per_relation
        }
        RETURN cand.milvus_id AS milvus_id, sum(w) AS score
        ORDER BY score DESC, milvus_id
        LIMIT $limit
        """
        async with self.driver.session(database=self.database) as session:
            records = await session.run(
                query,
                {
                    "seed_ids": seed_ids,
                    "k_per_relation": k_per_relation,
                    "limit": limit,
                },
            )
            return [r["milvus_id"] async for r in records]