123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
from typing import List, Dict, Any, Optional


class AsyncGraphExpansion:
    """Graph-side candidate retrieval helpers running Cypher through a driver.

    Every public method returns the ``milvus_id`` property of ``GraphChunk``
    nodes (presumably the ids of the corresponding vectors in Milvus).
    """

    # Relationship types linking a GraphChunk to the elements extracted from
    # it. Alternatives are separated by a bare "|": the older "|:" form is
    # deprecated in Neo4j 4.x and no longer accepted by Neo4j 5.
    _ELEMENT_RELS = "HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC"

    # Weighted 1-hop expansion used by expand_candidates. Each UNION ALL branch
    # collects at most $k_per_relation distinct neighbour chunks that share an
    # element of one relationship type with a seed chunk, tagged with that
    # type's weight. The outer query then sums the weights per milvus_id.
    # The aggregation must sit outside a CALL subquery because clauses cannot
    # follow a plain top-level UNION. Every branch must end in RETURN with
    # identical column names.
    _EXPAND_QUERY = """
    CALL {
        MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
        MATCH (gc)-[:HAS_ENTITY]->()<-[:HAS_ENTITY]-(nb:GraphChunk)
        WHERE nb <> gc
        WITH DISTINCT nb
        LIMIT $k_per_relation
        RETURN nb, 1.0 AS w
        UNION ALL
        MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
        MATCH (gc)-[:HAS_CONCEPT]->()<-[:HAS_CONCEPT]-(nb:GraphChunk)
        WHERE nb <> gc
        WITH DISTINCT nb
        LIMIT $k_per_relation
        RETURN nb, 0.7 AS w
        UNION ALL
        MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
        MATCH (gc)-[:HAS_TOPIC]->()<-[:HAS_TOPIC]-(nb:GraphChunk)
        WHERE nb <> gc
        WITH DISTINCT nb
        LIMIT $k_per_relation
        RETURN nb, 0.6 AS w
    }
    RETURN nb.milvus_id AS milvus_id, sum(w) AS score
    ORDER BY score DESC
    LIMIT $limit
    """

    def __init__(self, driver, database: str = "neo4j"):
        """
        :param driver: driver whose ``session(database=...)`` is used as an
            async context manager (presumably a ``neo4j.AsyncDriver``)
        :param database: database name every session is opened on
        """
        self.driver = driver
        self.database = database

    # ========== helpers ==========
    @staticmethod
    def _quote_label(label: str) -> str:
        """Return *label* as a backtick-quoted Cypher identifier.

        Labels cannot be sent as query parameters, so they are spliced into
        the query text. Doubling embedded backticks (Cypher's escape rule for
        quoted names) keeps a crafted label from breaking out of the quotes.

        :raises ValueError: if *label* is empty or ``None``
        """
        if not label:
            raise ValueError("label must be a non-empty string")
        return "`" + label.replace("`", "``") + "`"

    @staticmethod
    def _path_bound(max_len: int) -> int:
        """Validate a variable-length path bound before inlining it.

        Cypher does not allow parameters in ``*..N`` bounds, so the value
        ends up in the query text and must be a genuine positive ``int``.

        :raises ValueError: if *max_len* is not a positive integer
        """
        if isinstance(max_len, bool) or not isinstance(max_len, int) or max_len < 1:
            raise ValueError(f"max_len must be a positive integer, got {max_len!r}")
        return max_len

    # ========== co-occurrence ==========
    async def co_occurrence(
        self,
        seed_name: str,
        seed_label: str,
        other_names: Optional[List[str]] = None,
        limit: int = 500,
        top_n: int = 20,
    ) -> List[str]:
        """Return chunks in which the seed element co-occurs with other elements.

        :param seed_name: ``name`` property of the seed node
        :param seed_label: label of the seed node (e.g. 'Entity', 'Concept', 'Topic')
        :param other_names: names of the co-occurring elements to expand on.
            When falsy (``None`` *or* an empty list), the ``top_n`` elements
            that most often share a chunk with the seed are looked up first
            and used instead.
        :param limit: maximum number of ``milvus_id`` values returned
        :param top_n: number of top co-occurring elements used when
            ``other_names`` is not supplied
        :return: distinct ``milvus_id`` values of chunks linked to both the
            seed and at least one of the other elements
        :raises ValueError: if ``seed_label`` is empty
        """
        seed_lbl = self._quote_label(seed_label)
        rels = self._ELEMENT_RELS
        async with self.driver.session(database=self.database) as session:
            if not other_names:
                # Rank the elements that appear in the same chunks as the seed.
                query_top = f"""
                MATCH (seed:{seed_lbl} {{name: $seed_name}})<-[:{rels}]-(gc:GraphChunk)
                MATCH (gc)-[:{rels}]->(other)
                WHERE other <> seed
                RETURN other.name AS name, count(*) AS co_freq
                ORDER BY co_freq DESC
                LIMIT $top_n
                """
                result = await session.run(
                    query_top, {"seed_name": seed_name, "top_n": top_n}
                )
                other_names = [r["name"] async for r in result]
                if not other_names:
                    # Nothing co-occurs with the seed. `IN []` would match
                    # nothing, so skip the second round trip.
                    return []

            # Fetch chunks containing the seed plus any of the chosen elements.
            query_expand = f"""
            MATCH (seed:{seed_lbl} {{name: $seed_name}})<-[:{rels}]-(gc:GraphChunk)
            MATCH (gc)-[:{rels}]->(o)
            WHERE o.name IN $other_names
            RETURN DISTINCT gc.milvus_id AS milvus_id
            LIMIT $limit
            """
            result = await session.run(
                query_expand,
                {"seed_name": seed_name, "other_names": other_names, "limit": limit},
            )
            return [r["milvus_id"] async for r in result]

    # ========== path ==========
    async def shortest_path_chunks(
        self,
        a_name: str,
        a_label: str,
        b_name: str,
        b_label: str,
        max_len: int = 4,
        limit: int = 200,
    ) -> List[str]:
        """Return the chunks lying on a shortest path between two elements.

        The path may traverse HAS_ENTITY / HAS_CONCEPT / HAS_TOPIC /
        BELONGS_TO relationships in either direction, up to ``max_len`` hops.
        If a name matches several nodes, one shortest path is taken per
        (a, b) node pair.

        :param a_name: ``name`` property of the first element
        :param a_label: label of the first element
        :param b_name: ``name`` property of the second element
        :param b_label: label of the second element
        :param max_len: maximum number of hops in the path (positive int)
        :param limit: maximum number of ``milvus_id`` values returned
        :return: distinct ``milvus_id`` values of ``GraphChunk`` nodes on the path(s)
        :raises ValueError: if a label is empty or ``max_len`` is not a positive int
        """
        a_lbl = self._quote_label(a_label)
        b_lbl = self._quote_label(b_label)
        hops = self._path_bound(max_len)
        # `a <> b`: by default Neo4j raises when shortestPath is given the same
        # start and end node. Such pairs are skipped instead.
        query = f"""
        MATCH (a:{a_lbl} {{name: $a_name}}), (b:{b_lbl} {{name: $b_name}})
        WHERE a <> b
        CALL {{
            WITH a, b
            MATCH p = shortestPath(
                (a)-[:{self._ELEMENT_RELS}|BELONGS_TO*..{hops}]-(b)
            )
            RETURN p
            LIMIT 1
        }}
        WITH p
        UNWIND [n IN nodes(p) WHERE n:GraphChunk] AS gc
        RETURN DISTINCT gc.milvus_id AS milvus_id
        LIMIT $limit
        """
        async with self.driver.session(database=self.database) as session:
            # `limit` must be sent: the query references $limit.
            result = await session.run(
                query, {"a_name": a_name, "b_name": b_name, "limit": limit}
            )
            return [r["milvus_id"] async for r in result]

    # ========== expansion ==========
    async def expand_candidates(
        self, seed_ids: List[str], k_per_relation: int = 200, limit: int = 1000
    ) -> List[str]:
        """1-hop expansion from seed chunks through shared entities/concepts/topics.

        For each relationship type, the query collects up to ``k_per_relation``
        distinct chunks that share at least one element of that type with a
        seed chunk. Each such chunk receives that type's weight: entity 1.0,
        concept 0.7, topic 0.6. A chunk reached through several types
        accumulates the weights, and results are ordered by that score,
        highest first.

        Only the seed chunk itself is excluded from its own neighbours; other
        seed chunks may appear in the result.

        :param seed_ids: ``milvus_id`` values of the seed chunks
        :param k_per_relation: cap on distinct neighbours per relationship type
        :param limit: maximum number of ``milvus_id`` values returned
        :return: ``milvus_id`` values ordered by descending score (the scores
            themselves are not returned)
        """
        if not seed_ids:
            # `IN []` matches nothing; avoid opening a session for no result.
            return []
        async with self.driver.session(database=self.database) as session:
            result = await session.run(
                self._EXPAND_QUERY,
                {
                    "seed_ids": seed_ids,
                    "k_per_relation": k_per_relation,
                    "limit": limit,
                },
            )
            return [r["milvus_id"] async for r in result]
|