# graph_expansion.py (4.6 KB)

from typing import Any, Dict, List, Optional
  2. class AsyncGraphExpansion:
  3. def __init__(self, driver, database: str = "neo4j"):
  4. self.driver = driver
  5. self.database = database
  6. # ========== 共现 co-occurrence ==========
  7. async def co_occurrence(
  8. self,
  9. seed_name: str,
  10. seed_label: str,
  11. other_names: Optional[List[str]] = None,
  12. limit: int = 500,
  13. ) -> List[str]:
  14. """
  15. 找与种子要素共现的 chunk
  16. :param seed_name: 种子名称
  17. :param seed_label: 种子标签 ('Entity', 'Concept', 'Topic')
  18. :param other_names: 指定要扩展的其它要素名称列表(若为 None,先查共现TopN再扩展)
  19. :param limit: 返回数量上限
  20. """
  21. async with self.driver.session(database=self.database) as session:
  22. if not other_names:
  23. # 先统计高频共现要素
  24. query_top = f"""
  25. MATCH (seed:`{seed_label}` {{name:$seed_name}})
  26. <-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]-(gc:GraphChunk)
  27. MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(other)
  28. WHERE other <> seed
  29. RETURN other.name AS name, count(*) AS co_freq
  30. ORDER BY co_freq DESC
  31. LIMIT 20
  32. """
  33. records = await session.run(query_top, {"seed_name": seed_name})
  34. other_names = [r["name"] async for r in records]
  35. # 根据共现要素回捞 chunk
  36. query_expand = f"""
  37. MATCH (seed:`{seed_label}` {{name:$seed_name}})
  38. <-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]-(gc:GraphChunk)
  39. MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(o)
  40. WHERE o.name IN $other_names
  41. RETURN DISTINCT gc.milvus_id AS milvus_id
  42. LIMIT $limit
  43. """
  44. records = await session.run(
  45. query_expand,
  46. {"seed_name": seed_name, "other_names": other_names, "limit": limit},
  47. )
  48. return [r["milvus_id"] async for r in records]
  49. # ========== 路径 Path ==========
  50. async def shortest_path_chunks(
  51. self,
  52. a_name: str,
  53. a_label: str,
  54. b_name: str,
  55. b_label: str,
  56. max_len: int = 4,
  57. limit: int = 200,
  58. ) -> List[str]:
  59. """
  60. 找到两个要素之间的最短路径,并返回路径上的 chunk
  61. """
  62. query = f"""
  63. MATCH (a:`{a_label}` {{name:$a_name}}), (b:`{b_label}` {{name:$b_name}})
  64. CALL {{
  65. WITH a,b
  66. MATCH p = shortestPath(
  67. (a)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC|:BELONGS_TO*..{max_len}]-(b)
  68. )
  69. RETURN p LIMIT 1
  70. }}
  71. WITH p
  72. UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
  73. RETURN DISTINCT gc.milvus_id AS milvus_id
  74. LIMIT $limit
  75. """
  76. async with self.driver.session(database=self.database) as session:
  77. records = await session.run(query, {"a_name": a_name, "b_name": b_name})
  78. return [r["milvus_id"] async for r in records]
  79. # ========== 扩展 Expansion ==========
  80. async def expand_candidates(
  81. self, seed_ids: List[str], k_per_relation: int = 200, limit: int = 1000
  82. ) -> List[str]:
  83. """
  84. 基于候选 milvus_id 做 1-hop 扩展(实体/概念/主题),并按权重汇总
  85. """
  86. query = """
  87. MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
  88. // 同实体
  89. MATCH (gc)-[:HAS_ENTITY]->(e)<-[:HAS_ENTITY]-(gc2:GraphChunk)
  90. WHERE gc2 <> gc
  91. WITH DISTINCT gc2, 1.0 AS w
  92. LIMIT $k_per_relation
  93. UNION
  94. MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
  95. MATCH (gc)-[:HAS_CONCEPT]->(c)<-[:HAS_CONCEPT]-(gc3:GraphChunk)
  96. WHERE gc3 <> gc
  97. WITH DISTINCT gc3, 0.7 AS w
  98. LIMIT $k_per_relation
  99. UNION
  100. MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
  101. MATCH (gc)-[:HAS_TOPIC]->(t)<-[:HAS_TOPIC]-(gc4:GraphChunk)
  102. WHERE gc4 <> gc
  103. WITH DISTINCT gc4, 0.6 AS w
  104. LIMIT $k_per_relation
  105. RETURN DISTINCT gc2.milvus_id AS milvus_id, sum(w) AS score
  106. ORDER BY score DESC
  107. LIMIT $limit
  108. """
  109. async with self.driver.session(database=self.database) as session:
  110. records = await session.run(
  111. query,
  112. {
  113. "seed_ids": seed_ids,
  114. "k_per_relation": k_per_relation,
  115. "limit": limit,
  116. },
  117. )
  118. return [r["milvus_id"] async for r in records]