graph_expansion.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. import json
  2. from typing import List, Optional, Dict, Any
  class AsyncGraphExpansion:
      """Async Neo4j helpers that expand retrieval candidates over a chunk graph.

      The queries match ``GraphChunk`` nodes linked to ``Entity`` / ``Concept`` /
      ``Topic`` nodes through ``HAS_ENTITY`` / ``HAS_CONCEPT`` / ``HAS_TOPIC``
      relationships and return the chunks' ``milvus_id`` property.
      """

      def __init__(self, driver, database: str = "neo4j"):
          """
          :param driver: async Neo4j driver; only ``driver.session(database=...)``
              is used here (presumably a ``neo4j.AsyncDriver`` -- confirm).
          :param database: database name every session is opened against.
          """
          self.driver = driver
          self.database = database

      @staticmethod
      def _quote_label(label: str) -> str:
          """Return *label* as a backtick-quoted Cypher identifier.

          Labels cannot be sent as query parameters, so they are spliced into the
          query text; doubling embedded backticks (Cypher's escape inside a quoted
          name) keeps a caller-supplied label from closing the quote early.
          """
          return "`" + str(label).replace("`", "``") + "`"

      # ========== Co-occurrence ==========
      async def co_occurrence(
          self,
          seed_name: str,
          seed_label: str,
          other_names: Optional[List[str]] = None,
          limit: int = 500,
          top_k: int = 20,
      ) -> List[str]:
          """
          Find chunks in which the seed element co-occurs with other elements.

          :param seed_name: name of the seed node.
          :param seed_label: label of the seed node ('Entity', 'Concept', 'Topic').
          :param other_names: names of the elements to expand with. When None or
              empty, the ``top_k`` elements co-occurring most often with the seed
              are looked up first and used instead.
          :param limit: maximum number of chunk ids returned.
          :param top_k: number of co-occurring elements used when ``other_names``
              is not given (defaults to 20, the previously hard-coded value).
          :return: distinct ``milvus_id`` values, as strings, of chunks linked to
              the seed and to at least one of the other elements.
          """
          seed_lbl = self._quote_label(seed_label)
          async with self.driver.session(database=self.database) as session:
              if not other_names:
                  # Rank the elements that appear in the same chunks as the seed.
                  query_top = f"""
                  MATCH (seed:{seed_lbl} {{name:$seed_name}})
                        <-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]-(gc:GraphChunk)
                  MATCH (gc)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]->(other)
                  WHERE other <> seed
                  RETURN other.name AS name, count(*) AS co_freq
                  ORDER BY co_freq DESC
                  LIMIT $top_k
                  """
                  records = await session.run(
                      query_top, {"seed_name": seed_name, "top_k": top_k}
                  )
                  other_names = [r["name"] async for r in records]
                  if not other_names:
                      # Nothing co-occurs with the seed, so the expansion query
                      # below could not match any chunk.
                      return []
              # Fetch chunks containing the seed and any of the chosen elements.
              query_expand = f"""
              MATCH (seed:{seed_lbl} {{name:$seed_name}})
                    <-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]-(gc:GraphChunk)
              MATCH (gc)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]->(o)
              WHERE o.name IN $other_names
              RETURN DISTINCT gc.milvus_id AS milvus_id
              LIMIT $limit
              """
              records = await session.run(
                  query_expand,
                  {"seed_name": seed_name, "other_names": other_names, "limit": limit},
              )
              return [str(r["milvus_id"]) async for r in records]

      # ========== Multi-seed co-occurrence (each seed carries its own label) ==========
      async def co_occurrence_multi_mixed_labels(
          self,
          seeds: List[
              Dict[str, str]
          ],  # [{"name": "...", "label": "Entity|Concept|Topic"}, ...]
          other_names: Optional[List[str]] = None,
          limit: int = 500,
          top_k: int = 20,
          min_support: int = 1,  # minimum number of distinct seeds to co-occur with
      ) -> List[str]:
          """
          Co-occurrence expansion for several seeds whose labels may differ.

          :param seeds: dicts with ``"name"`` and ``"label"`` keys; every label must
              be 'Entity', 'Concept' or 'Topic'.
          :param other_names: elements to expand with. When None or empty, up to
              ``top_k`` elements co-occurring with at least ``min_support``
              distinct seed names are selected, ranked by that support and then by
              the number of distinct shared chunks.
          :param limit: maximum number of chunk ids returned.
          :param top_k: number of co-occurring elements selected automatically.
          :param min_support: minimum number of distinct seed names an element must
              co-occur with to be selected.
          :return: distinct ``milvus_id`` values, as strings, of chunks containing
              at least one seed and at least one of the other elements.
          :raises ValueError: if any seed's label is missing or not allowed.
          """
          if not seeds:
              return []
          # Check labels before reading any other seed keys, so a seed with a bad
          # or missing label is rejected up front with a clear error.
          allowed = {"Entity", "Concept", "Topic"}
          if not {s.get("label") for s in seeds} <= allowed:
              raise ValueError("Seed label must be one of Entity/Concept/Topic")
          # Deduplicate names per label so each Cypher branch below only matches
          # nodes of its own label, once per name.
          seed_entities = sorted({s["name"] for s in seeds if s["label"] == "Entity"})
          seed_concepts = sorted({s["name"] for s in seeds if s["label"] == "Concept"})
          seed_topics = sorted({s["name"] for s in seeds if s["label"] == "Topic"})
          async with self.driver.session(database=self.database) as session:
              # Select the TopN co-occurring elements across all seeds when the
              # caller did not name them explicitly.
              if not other_names:
                  query_top = """
                  // ========== 先收敛到与 seeds 共现的 gc + sname ==========
                  CALL {
                      WITH $seed_entities AS se
                      UNWIND se AS sname
                      MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
                      RETURN gc, sname
                      UNION
                      WITH $seed_concepts AS sc
                      UNWIND sc AS sname
                      MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
                      RETURN gc, sname
                      UNION
                      WITH $seed_topics AS st
                      UNWIND st AS sname
                      MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
                      RETURN gc, sname
                  }
                  // ========== 在这些 gc 上一次性发散到 other ==========
                  MATCH (gc)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]->(other)
                  // 排除把种子本身当作 other 的情况(基于标签分桶,仍然命中索引)
                  WHERE NOT (
                      (other:Entity AND other.name IN $seed_entities) OR
                      (other:Concept AND other.name IN $seed_concepts) OR
                      (other:Topic AND other.name IN $seed_topics)
                  )
                  WITH other.name AS name, gc, sname
                  // 跨 gc 汇总频次;跨 sname 汇总支持度(命中的不同 seed 数)
                  WITH name,
                       count(DISTINCT gc) AS co_freq,
                       count(DISTINCT sname) AS seed_support
                  WHERE seed_support >= $min_support
                  ORDER BY seed_support DESC, co_freq DESC
                  LIMIT $top_k
                  RETURN name
                  """
                  params_top = {
                      "seed_entities": seed_entities,
                      "seed_concepts": seed_concepts,
                      "seed_topics": seed_topics,
                      "top_k": top_k,
                      "min_support": min_support,
                  }
                  recs = await session.run(query_top, params_top)
                  other_names = [r["name"] async for r in recs]
                  if not other_names:
                      return []
              # Expansion: collect chunks linked to any seed, then resolve each
              # other name by per-label equality match (UNWIND + exact match, which
              # the author chose so name lookups can use indexes rather than an
              # `IN $list` scan -- confirm against the schema's indexes).
              query_expand = """
              // 先生成候选 gc(包含任一 seed)
              CALL {
                  WITH $seed_entities AS se
                  UNWIND se AS sname
                  MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
                  RETURN DISTINCT gc
                  UNION
                  WITH $seed_concepts AS sc
                  UNWIND sc AS sname
                  MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
                  RETURN DISTINCT gc
                  UNION
                  WITH $seed_topics AS st
                  UNWIND st AS sname
                  MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
                  RETURN DISTINCT gc
              }
              // 用索引命中 other(UNWIND 列表,让等值匹配可走索引)
              UNWIND $other_names AS oname
              CALL {
                  WITH oname
                  MATCH (o:Entity {name:oname}) RETURN o
                  UNION
                  WITH oname
                  MATCH (o:Concept {name:oname}) RETURN o
                  UNION
                  WITH oname
                  MATCH (o:Topic {name:oname}) RETURN o
              }
              MATCH (gc)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]->(o)
              RETURN DISTINCT gc.milvus_id AS milvus_id
              LIMIT $limit
              """
              params_expand = {
                  "seed_entities": seed_entities,
                  "seed_concepts": seed_concepts,
                  "seed_topics": seed_topics,
                  "other_names": other_names,
                  "limit": limit,
              }
              recs2 = await session.run(query_expand, params_expand)
              # Stringify like the sibling methods so the List[str] contract holds.
              return [str(r["milvus_id"]) async for r in recs2]

      # ========== Path ==========
      async def shortest_path_chunks(
          self,
          a_name: str,
          a_label: str,
          b_name: str,
          b_label: str,
          max_len: int = 4,
          limit: int = 200,
      ) -> List[str]:
          """
          Find a shortest path between two elements and return the chunks on it.

          For every matching (a, b) node pair with a != b, one shortest path of at
          most ``max_len`` HAS_ENTITY/HAS_CONCEPT/HAS_TOPIC/BELONGS_TO hops (either
          direction) is taken, and the ``milvus_id`` of each GraphChunk node on it
          is collected.

          :param a_name: name of the first endpoint node.
          :param a_label: label of the first endpoint node.
          :param b_name: name of the second endpoint node.
          :param b_label: label of the second endpoint node.
          :param max_len: maximum path length in relationships (must be >= 1).
          :param limit: maximum number of chunk ids returned.
          :return: distinct ``milvus_id`` values as strings.
          :raises ValueError: if ``max_len`` is smaller than 1.
          """
          # Variable-length bounds cannot be parameterised in Cypher, so max_len is
          # interpolated; coercing to int keeps arbitrary text out of the query.
          max_len = int(max_len)
          if max_len < 1:
              raise ValueError("max_len must be >= 1")
          # `a <> b` avoids Neo4j's error for shortestPath with identical endpoints.
          # LIMIT is bound as a real parameter (the old `${limit}` rendered as an
          # undefined parameter such as `$200`).
          query = f"""
          MATCH (a:{self._quote_label(a_label)} {{name:$a_name}}),
                (b:{self._quote_label(b_label)} {{name:$b_name}})
          WHERE a <> b
          CALL {{
              WITH a, b
              MATCH p = shortestPath(
                  (a)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC|BELONGS_TO*..{max_len}]-(b)
              )
              RETURN p LIMIT 1
          }}
          WITH p
          UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
          RETURN DISTINCT gc.milvus_id AS milvus_id
          LIMIT $limit
          """
          async with self.driver.session(database=self.database) as session:
              records = await session.run(
                  query, {"a_name": a_name, "b_name": b_name, "limit": limit}
              )
              return [str(r["milvus_id"]) async for r in records]

      # ========== Expansion ==========
      async def expand_candidates(
          self, seed_ids: List[int], k_per_relation: int = 200, limit: int = 1000
      ) -> List[Dict[str, Any]]:
          """
          One-hop expansion from seed chunks (by milvus_id) through shared elements.

          Candidates are chunks sharing an Entity (weight 1.0), Concept (0.7) or
          Topic (0.6) with any seed chunk; seed chunks themselves are excluded.
          Each relation type contributes at most ``k_per_relation`` distinct
          candidates (an unordered subset when more exist), and a candidate's
          score is the sum of the weights of the relation types it was found by.

          :param seed_ids: ``milvus_id`` values of the seed chunks.
          :param k_per_relation: cap on candidates taken per relation type.
          :param limit: maximum number of results.
          :return: dicts with ``milvus_id``, ``chunk_id``, ``doc_id`` and ``score``,
              sorted by score descending. ``chunk_id`` / ``doc_id`` are read from
              the GraphChunk properties of those names -- adjust if the schema
              differs.
          """
          if not seed_ids:
              return []
          # Each UNION branch has its own variable scope, so parameters are used
          # directly in every branch; LIMIT also cannot refer to a variable.
          query = """
          CALL {
              // 同实体
              MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
              MATCH (gc0)-[:HAS_ENTITY]->(:Entity)<-[:HAS_ENTITY]-(cand:GraphChunk)
              WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
              WITH DISTINCT cand, 1.0 AS w
              LIMIT $k_per_relation
              RETURN cand AS gc, w
              UNION
              // 同概念
              MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
              MATCH (gc0)-[:HAS_CONCEPT]->(:Concept)<-[:HAS_CONCEPT]-(cand:GraphChunk)
              WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
              WITH DISTINCT cand, 0.7 AS w
              LIMIT $k_per_relation
              RETURN cand AS gc, w
              UNION
              // 同主题
              MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
              MATCH (gc0)-[:HAS_TOPIC]->(:Topic)<-[:HAS_TOPIC]-(cand:GraphChunk)
              WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
              WITH DISTINCT cand, 0.6 AS w
              LIMIT $k_per_relation
              RETURN cand AS gc, w
          }
          WITH gc, sum(w) AS score
          RETURN
              gc.milvus_id AS milvus_id,
              gc.chunk_id AS chunk_id, // 若你的属性名不同,请改成实际字段
              gc.doc_id AS doc_id, // 同上
              score
          ORDER BY score DESC
          LIMIT $limit
          """
          async with self.driver.session(database=self.database) as session:
              records = await session.run(
                  query,
                  {
                      "seed_ids": seed_ids,
                      "k_per_relation": k_per_relation,
                      "limit": limit,
                  },
              )
              return [
                  {
                      "milvus_id": r["milvus_id"],
                      "chunk_id": r["chunk_id"],
                      "doc_id": r["doc_id"],
                      "score": r["score"],
                  }
                  async for r in records
              ]