# graph_expansion.py (11 KB)
import json
from typing import List, Optional, Dict, Any, Iterable
  3. class AsyncGraphExpansion:
  4. _REL_ANY = ":HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC"
  5. def __init__(self, driver, database: str = "neo4j"):
  6. self.driver = driver
  7. self.database = database
  8. # ========== 共现 co-occurrence ==========
  9. async def co_occurrence(
  10. self,
  11. seed_name: str,
  12. seed_label: str,
  13. other_names: Optional[List[str]] = None,
  14. limit: int = 500,
  15. ) -> List[str]:
  16. """
  17. 找与种子要素共现的 chunk
  18. """
  19. REL_ANY = ":HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC" # 非弃用:只在第一个关系类型前加冒号
  20. async with self.driver.session(database=self.database) as session:
  21. if not other_names:
  22. # 一次往返:显式作用域子查询 + 非弃用关系“或”写法
  23. cypher = f"""
  24. MATCH (seed:`{seed_label}` {{name:$seed_name}})
  25. CALL (seed) {{
  26. MATCH (seed)<-[{REL_ANY}]-(gc:GraphChunk)
  27. MATCH (gc)-[{REL_ANY}]->(other)
  28. WHERE other <> seed
  29. WITH other.name AS name, count(*) AS co_freq
  30. ORDER BY co_freq DESC
  31. RETURN collect(name)[..20] AS sel_names
  32. }}
  33. WITH seed, sel_names
  34. WHERE size(sel_names) > 0
  35. MATCH (seed)<-[{REL_ANY}]-(gc2:GraphChunk)
  36. MATCH (gc2)-[{REL_ANY}]->(o2)
  37. WHERE o2.name IN sel_names
  38. RETURN DISTINCT gc2.milvus_id AS milvus_id
  39. LIMIT $limit
  40. """
  41. records = await session.run(
  42. cypher, {"seed_name": seed_name, "limit": limit}
  43. )
  44. return [str(r["milvus_id"]) async for r in records]
  45. # 已提供 other_names:直接扩展(同样使用非弃用关系“或”写法)
  46. cypher = f"""
  47. MATCH (seed:`{seed_label}` {{name:$seed_name}})
  48. <-[{REL_ANY}]-(gc:GraphChunk)
  49. MATCH (gc)-[{REL_ANY}]->(o)
  50. WHERE o.name IN $other_names
  51. RETURN DISTINCT gc.milvus_id AS milvus_id
  52. LIMIT $limit
  53. """
  54. records = await session.run(
  55. cypher,
  56. {"seed_name": seed_name, "other_names": other_names, "limit": limit},
  57. )
  58. return [str(r["milvus_id"]) async for r in records]
  59. # ========== 多种子(每个 seed 自带标签)共现(性能优化版) ==========
  60. async def co_occurrence_multi_mixed_labels(
  61. self,
  62. seeds: List[
  63. Dict[str, str]
  64. ], # [{"name": "...", "label": "Entity|Concept|Topic"}, ...]
  65. other_names: Optional[List[str]] = None,
  66. limit: int = 500,
  67. top_k: int = 20,
  68. min_support: int = 1, # 至少与多少个不同 seed 共现
  69. ) -> List[str]:
  70. if not seeds:
  71. return []
  72. # 1) 标签白名单 + 去重分桶(减少 Cypher 分支中过滤与扫描)
  73. allowed = {"Entity", "Concept", "Topic"}
  74. seed_entities = sorted({s["name"] for s in seeds if s.get("label") == "Entity"})
  75. seed_concepts = sorted(
  76. {s["name"] for s in seeds if s.get("label") == "Concept"}
  77. )
  78. seed_topics = sorted({s["name"] for s in seeds if s.get("label") == "Topic"})
  79. if not (set(x.get("label") for x in seeds) <= allowed):
  80. raise ValueError("Seed label must be one of Entity/Concept/Topic")
  81. async with self.driver.session(database=self.database) as session:
  82. # 2) 统计跨 seeds 的共现 TopN(当未显式提供 other_names 时)
  83. if not other_names:
  84. query_top = """
  85. // ========== 先收敛到与 seeds 共现的 gc + sname ==========
  86. CALL {
  87. WITH $seed_entities AS se
  88. UNWIND se AS sname
  89. MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
  90. RETURN gc, sname
  91. UNION
  92. WITH $seed_concepts AS sc
  93. UNWIND sc AS sname
  94. MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
  95. RETURN gc, sname
  96. UNION
  97. WITH $seed_topics AS st
  98. UNWIND st AS sname
  99. MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
  100. RETURN gc, sname
  101. }
  102. // ========== 在这些 gc 上一次性发散到 other ==========
  103. MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(other)
  104. // 排除把种子本身当作 other 的情况(基于标签分桶,仍然命中索引)
  105. WHERE NOT (
  106. (other:Entity AND other.name IN $seed_entities) OR
  107. (other:Concept AND other.name IN $seed_concepts) OR
  108. (other:Topic AND other.name IN $seed_topics)
  109. )
  110. WITH other.name AS name, gc, sname
  111. // 跨 gc 汇总频次;跨 sname 汇总支持度(命中的不同 seed 数)
  112. WITH name,
  113. count(DISTINCT gc) AS co_freq,
  114. count(DISTINCT sname) AS seed_support
  115. WHERE seed_support >= $min_support
  116. ORDER BY seed_support DESC, co_freq DESC
  117. LIMIT $top_k
  118. RETURN name
  119. """
  120. params_top = {
  121. "seed_entities": seed_entities,
  122. "seed_concepts": seed_concepts,
  123. "seed_topics": seed_topics,
  124. "top_k": top_k,
  125. "min_support": min_support,
  126. }
  127. recs = await session.run(query_top, params_top)
  128. other_names = [r["name"] async for r in recs]
  129. if not other_names:
  130. return []
  131. # 3) 回捞阶段:先确定“与 seeds 相连的 gc”,再用 other_names 索引命中
  132. # 这里用 UNWIND onames + 按标签逐类 MATCH,确保使用 name 索引,
  133. # 避免 `o.name IN $list` 走全表扫描/低效 plan。
  134. query_expand = """
  135. // 先生成候选 gc(包含任一 seed)
  136. CALL {
  137. WITH $seed_entities AS se
  138. UNWIND se AS sname
  139. MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
  140. RETURN DISTINCT gc
  141. UNION
  142. WITH $seed_concepts AS sc
  143. UNWIND sc AS sname
  144. MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
  145. RETURN DISTINCT gc
  146. UNION
  147. WITH $seed_topics AS st
  148. UNWIND st AS sname
  149. MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
  150. RETURN DISTINCT gc
  151. }
  152. // 用索引命中 other(UNWIND 列表,让等值匹配可走索引)
  153. UNWIND $other_names AS oname
  154. CALL {
  155. WITH oname
  156. MATCH (o:Entity {name:oname}) RETURN o
  157. UNION
  158. WITH oname
  159. MATCH (o:Concept {name:oname}) RETURN o
  160. UNION
  161. WITH oname
  162. MATCH (o:Topic {name:oname}) RETURN o
  163. }
  164. MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(o)
  165. RETURN DISTINCT gc.milvus_id AS milvus_id
  166. LIMIT $limit
  167. """
  168. params_expand = {
  169. "seed_entities": seed_entities,
  170. "seed_concepts": seed_concepts,
  171. "seed_topics": seed_topics,
  172. "other_names": other_names,
  173. "limit": limit,
  174. }
  175. recs2 = await session.run(query_expand, params_expand)
  176. return [r["milvus_id"] async for r in recs2]
  177. # ========== 路径 Path ==========
  178. async def shortest_path_chunks(
  179. self,
  180. a_name: str,
  181. a_label: str,
  182. b_name: str,
  183. b_label: str,
  184. max_len: int = 4,
  185. limit: int = 200,
  186. ) -> List[str]:
  187. """
  188. 找到两个要素之间的最短路径,并返回路径上的 chunk
  189. """
  190. query = f"""
  191. MATCH (a:`{a_label}` {{name:$a_name}}), (b:`{b_label}` {{name:$b_name}})
  192. CALL {{
  193. WITH a,b
  194. MATCH p = shortestPath(
  195. (a)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC|:BELONGS_TO*..{max_len}]-(b)
  196. )
  197. RETURN p LIMIT 1
  198. }}
  199. WITH p
  200. UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
  201. RETURN DISTINCT gc.milvus_id AS milvus_id
  202. LIMIT ${limit}
  203. """
  204. async with self.driver.session(database=self.database) as session:
  205. records = await session.run(query, {"a_name": a_name, "b_name": b_name})
  206. return [str(r["milvus_id"]) async for r in records]
  207. # ========== 扩展 Expansion ==========
  208. async def expand_candidates(
  209. self, seed_ids: List[int], k_per_relation: int = 200, limit: int = 1000
  210. ) -> List[Dict[str, Any]]:
  211. """
  212. 基于候选 milvus_id 做 1-hop 扩展(实体/概念/主题),并按权重汇总
  213. """
  214. query = """
  215. CALL {
  216. WITH $seed_ids AS seed_ids, $k_per_relation AS k
  217. // 同实体
  218. MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN seed_ids
  219. MATCH (gc0)-[:HAS_ENTITY]->(:Entity)<-[:HAS_ENTITY]-(cand:GraphChunk)
  220. WHERE cand <> gc0 AND NOT cand.milvus_id IN seed_ids
  221. WITH DISTINCT cand, 1.0 AS w
  222. LIMIT k
  223. RETURN cand AS gc, w
  224. UNION
  225. // 同概念
  226. WITH seed_ids, k
  227. MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN seed_ids
  228. MATCH (gc0)-[:HAS_CONCEPT]->(:Concept)<-[:HAS_CONCEPT]-(cand:GraphChunk)
  229. WHERE cand <> gc0 AND NOT cand.milvus_id IN seed_ids
  230. WITH DISTINCT cand, 0.7 AS w
  231. LIMIT k
  232. RETURN cand AS gc, w
  233. UNION
  234. // 同主题
  235. WITH seed_ids, k
  236. MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN seed_ids
  237. MATCH (gc0)-[:HAS_TOPIC]->(:Topic)<-[:HAS_TOPIC]-(cand:GraphChunk)
  238. WHERE cand <> gc0 AND NOT cand.milvus_id IN seed_ids
  239. WITH DISTINCT cand, 0.6 AS w
  240. LIMIT k
  241. RETURN cand AS gc, w
  242. }
  243. WITH gc, sum(w) AS score
  244. RETURN
  245. gc.milvus_id AS milvus_id,
  246. gc.chunk_id AS chunk_id, // 若你的属性名不同,请改成实际字段
  247. gc.doc_id AS doc_id, // 同上
  248. score
  249. ORDER BY score DESC
  250. LIMIT $limit
  251. """
  252. async with self.driver.session(database=self.database) as session:
  253. records = await session.run(
  254. query,
  255. {
  256. "seed_ids": seed_ids,
  257. "k_per_relation": k_per_relation,
  258. "limit": limit,
  259. },
  260. )
  261. return [
  262. {
  263. "milvus_id": r["milvus_id"],
  264. "chunk_id": r["chunk_id"],
  265. "doc_id": r["doc_id"],
  266. "score": r["score"],
  267. }
  268. async for r in records
  269. ]