graph_expansion.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. import json
  2. from typing import List, Optional, Dict, Any, Iterable
  3. class AsyncGraphExpansion:
  4. _REL_ANY = ":HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC"
  5. def __init__(self, driver, database: str = "neo4j"):
  6. self.driver = driver
  7. self.database = database
  8. # ========== 共现 co-occurrence ==========
  9. async def co_occurrence(
  10. self,
  11. seed_name: str,
  12. seed_label: str,
  13. other_names: Optional[List[str]] = None,
  14. limit: int = 500,
  15. ) -> List[str]:
  16. """
  17. 找与种子要素共现的 chunk
  18. """
  19. REL_ANY = (
  20. ":HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC" # 非弃用:只在第一个关系类型前加冒号
  21. )
  22. async with self.driver.session(database=self.database) as session:
  23. if not other_names:
  24. # 一次往返:显式作用域子查询 + 非弃用关系“或”写法
  25. cypher = f"""
  26. MATCH (seed:`{seed_label}` {{name:$seed_name}})
  27. CALL (seed) {{
  28. MATCH (seed)<-[{REL_ANY}]-(gc:GraphChunk)
  29. MATCH (gc)-[{REL_ANY}]->(other)
  30. WHERE other <> seed
  31. WITH other.name AS name, count(*) AS co_freq
  32. ORDER BY co_freq DESC
  33. RETURN collect(name)[..20] AS sel_names
  34. }}
  35. WITH seed, sel_names
  36. WHERE size(sel_names) > 0
  37. MATCH (seed)<-[{REL_ANY}]-(gc2:GraphChunk)
  38. MATCH (gc2)-[{REL_ANY}]->(o2)
  39. WHERE o2.name IN sel_names
  40. RETURN DISTINCT gc2.milvus_id AS milvus_id
  41. LIMIT $limit
  42. """
  43. records = await session.run(
  44. cypher, {"seed_name": seed_name, "limit": limit}
  45. )
  46. return [str(r["milvus_id"]) async for r in records]
  47. # 已提供 other_names:直接扩展(同样使用非弃用关系“或”写法)
  48. cypher = f"""
  49. MATCH (seed:`{seed_label}` {{name:$seed_name}})
  50. <-[{REL_ANY}]-(gc:GraphChunk)
  51. MATCH (gc)-[{REL_ANY}]->(o)
  52. WHERE o.name IN $other_names
  53. RETURN DISTINCT gc.milvus_id AS milvus_id
  54. LIMIT $limit
  55. """
  56. records = await session.run(
  57. cypher,
  58. {"seed_name": seed_name, "other_names": other_names, "limit": limit},
  59. )
  60. return [str(r["milvus_id"]) async for r in records]
  61. # ========== 多种子(每个 seed 自带标签)共现(性能优化版) ==========
  62. async def co_occurrence_multi_mixed_labels(
  63. self,
  64. seeds: List[
  65. Dict[str, str]
  66. ], # [{"name": "...", "label": "Entity|Concept|Topic"}, ...]
  67. other_names: Optional[List[str]] = None,
  68. limit: int = 500,
  69. top_k: int = 20,
  70. min_support: int = 1, # 至少与多少个不同 seed 共现
  71. ) -> List[str]:
  72. if not seeds:
  73. return []
  74. # 1) 标签白名单 + 去重分桶(减少 Cypher 分支中过滤与扫描)
  75. allowed = {"Entity", "Concept", "Topic"}
  76. seed_entities = sorted({s["name"] for s in seeds if s.get("label") == "Entity"})
  77. seed_concepts = sorted(
  78. {s["name"] for s in seeds if s.get("label") == "Concept"}
  79. )
  80. seed_topics = sorted({s["name"] for s in seeds if s.get("label") == "Topic"})
  81. if not (set(x.get("label") for x in seeds) <= allowed):
  82. raise ValueError("Seed label must be one of Entity/Concept/Topic")
  83. async with self.driver.session(database=self.database) as session:
  84. # 2) 统计跨 seeds 的共现 TopN(当未显式提供 other_names 时)
  85. if not other_names:
  86. query_top = """
  87. // ========== 先收敛到与 seeds 共现的 gc + sname ==========
  88. CALL {
  89. WITH $seed_entities AS se
  90. UNWIND se AS sname
  91. MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
  92. RETURN gc, sname
  93. UNION
  94. WITH $seed_concepts AS sc
  95. UNWIND sc AS sname
  96. MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
  97. RETURN gc, sname
  98. UNION
  99. WITH $seed_topics AS st
  100. UNWIND st AS sname
  101. MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
  102. RETURN gc, sname
  103. }
  104. // ========== 在这些 gc 上一次性发散到 other ==========
  105. MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(other)
  106. // 排除把种子本身当作 other 的情况(基于标签分桶,仍然命中索引)
  107. WHERE NOT (
  108. (other:Entity AND other.name IN $seed_entities) OR
  109. (other:Concept AND other.name IN $seed_concepts) OR
  110. (other:Topic AND other.name IN $seed_topics)
  111. )
  112. WITH other.name AS name, gc, sname
  113. // 跨 gc 汇总频次;跨 sname 汇总支持度(命中的不同 seed 数)
  114. WITH name,
  115. count(DISTINCT gc) AS co_freq,
  116. count(DISTINCT sname) AS seed_support
  117. WHERE seed_support >= $min_support
  118. ORDER BY seed_support DESC, co_freq DESC
  119. LIMIT $top_k
  120. RETURN name
  121. """
  122. params_top = {
  123. "seed_entities": seed_entities,
  124. "seed_concepts": seed_concepts,
  125. "seed_topics": seed_topics,
  126. "top_k": top_k,
  127. "min_support": min_support,
  128. }
  129. recs = await session.run(query_top, params_top)
  130. other_names = [r["name"] async for r in recs]
  131. if not other_names:
  132. return []
  133. # 3) 回捞阶段:先确定“与 seeds 相连的 gc”,再用 other_names 索引命中
  134. # 这里用 UNWIND onames + 按标签逐类 MATCH,确保使用 name 索引,
  135. # 避免 `o.name IN $list` 走全表扫描/低效 plan。
  136. query_expand = """
  137. // 先生成候选 gc(包含任一 seed)
  138. CALL {
  139. WITH $seed_entities AS se
  140. UNWIND se AS sname
  141. MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
  142. RETURN DISTINCT gc
  143. UNION
  144. WITH $seed_concepts AS sc
  145. UNWIND sc AS sname
  146. MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
  147. RETURN DISTINCT gc
  148. UNION
  149. WITH $seed_topics AS st
  150. UNWIND st AS sname
  151. MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
  152. RETURN DISTINCT gc
  153. }
  154. // 用索引命中 other(UNWIND 列表,让等值匹配可走索引)
  155. UNWIND $other_names AS oname
  156. CALL {
  157. WITH oname
  158. MATCH (o:Entity {name:oname}) RETURN o
  159. UNION
  160. WITH oname
  161. MATCH (o:Concept {name:oname}) RETURN o
  162. UNION
  163. WITH oname
  164. MATCH (o:Topic {name:oname}) RETURN o
  165. }
  166. MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(o)
  167. RETURN DISTINCT gc.milvus_id AS milvus_id
  168. LIMIT $limit
  169. """
  170. params_expand = {
  171. "seed_entities": seed_entities,
  172. "seed_concepts": seed_concepts,
  173. "seed_topics": seed_topics,
  174. "other_names": other_names,
  175. "limit": limit,
  176. }
  177. recs2 = await session.run(query_expand, params_expand)
  178. return [r["milvus_id"] async for r in recs2]
  179. # ========== 路径 Path ==========
  180. async def shortest_path_chunks(
  181. self,
  182. a_name: str,
  183. a_label: str,
  184. b_name: str,
  185. b_label: str,
  186. max_len: int = 4,
  187. limit: int = 200,
  188. ) -> List[str]:
  189. """
  190. 找到两个要素之间的最短路径,并返回路径上的 chunk
  191. """
  192. query = f"""
  193. MATCH (a:`{a_label}` {{name:$a_name}}), (b:`{b_label}` {{name:$b_name}})
  194. CALL {{
  195. WITH a,b
  196. MATCH p = shortestPath(
  197. (a)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC|:BELONGS_TO*..{max_len}]-(b)
  198. )
  199. RETURN p LIMIT 1
  200. }}
  201. WITH p
  202. UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
  203. RETURN DISTINCT gc.milvus_id AS milvus_id
  204. LIMIT ${limit}
  205. """
  206. async with self.driver.session(database=self.database) as session:
  207. records = await session.run(query, {"a_name": a_name, "b_name": b_name})
  208. return [str(r["milvus_id"]) async for r in records]
  209. # ========== 扩展 Expansion ==========
  210. async def expand_candidates(
  211. self, seed_ids: List[int], k_per_relation: int = 200, limit: int = 1000
  212. ) -> List[Dict[str, Any]]:
  213. """
  214. 基于候选 milvus_id 做 1-hop 扩展(实体/概念/主题),并按权重汇总
  215. """
  216. query = """
  217. CALL {
  218. WITH $seed_ids AS seed_ids, $k_per_relation AS k
  219. // 同实体
  220. MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN seed_ids
  221. MATCH (gc0)-[:HAS_ENTITY]->(:Entity)<-[:HAS_ENTITY]-(cand:GraphChunk)
  222. WHERE cand <> gc0 AND NOT cand.milvus_id IN seed_ids
  223. WITH DISTINCT cand, 1.0 AS w
  224. LIMIT k
  225. RETURN cand AS gc, w
  226. UNION
  227. // 同概念
  228. WITH seed_ids, k
  229. MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN seed_ids
  230. MATCH (gc0)-[:HAS_CONCEPT]->(:Concept)<-[:HAS_CONCEPT]-(cand:GraphChunk)
  231. WHERE cand <> gc0 AND NOT cand.milvus_id IN seed_ids
  232. WITH DISTINCT cand, 0.7 AS w
  233. LIMIT k
  234. RETURN cand AS gc, w
  235. UNION
  236. // 同主题
  237. WITH seed_ids, k
  238. MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN seed_ids
  239. MATCH (gc0)-[:HAS_TOPIC]->(:Topic)<-[:HAS_TOPIC]-(cand:GraphChunk)
  240. WHERE cand <> gc0 AND NOT cand.milvus_id IN seed_ids
  241. WITH DISTINCT cand, 0.6 AS w
  242. LIMIT k
  243. RETURN cand AS gc, w
  244. }
  245. WITH gc, sum(w) AS score
  246. RETURN
  247. gc.milvus_id AS milvus_id,
  248. gc.chunk_id AS chunk_id, // 若你的属性名不同,请改成实际字段
  249. gc.doc_id AS doc_id, // 同上
  250. score
  251. ORDER BY score DESC
  252. LIMIT $limit
  253. """
  254. async with self.driver.session(database=self.database) as session:
  255. records = await session.run(
  256. query,
  257. {
  258. "seed_ids": seed_ids,
  259. "k_per_relation": k_per_relation,
  260. "limit": limit,
  261. },
  262. )
  263. return [
  264. {
  265. "milvus_id": r["milvus_id"],
  266. "chunk_id": r["chunk_id"],
  267. "doc_id": r["doc_id"],
  268. "score": r["score"],
  269. }
  270. async for r in records
  271. ]