|
@@ -0,0 +1,291 @@
|
|
|
+import json
|
|
|
+from typing import List, Optional, Dict, Any, Iterable
|
|
|
+
|
|
|
+
|
|
|
class AsyncGraphExpansion:
    """Async graph-expansion queries over a graph of chunks and their elements.

    Every method opens a session on ``driver`` for ``database``, runs Cypher
    that walks ``GraphChunk`` nodes and the nodes they point to through
    ``HAS_ENTITY`` / ``HAS_CONCEPT`` / ``HAS_TOPIC`` relationships, and returns
    the ``milvus_id`` property of the matching chunks.

    ``driver`` must expose ``session(database=...)`` returning an async context
    manager whose ``run(query, params)`` coroutine yields an async-iterable
    result (presumably a neo4j ``AsyncDriver`` -- confirm against the caller).
    """

    # Relationship-type alternation in the non-deprecated form: only the first
    # type carries a colon. The legacy `[:A|:B]` spelling is rejected by
    # Neo4j 5, which the `CALL (seed) {...}` syntax used below already requires.
    _REL_ANY = ":HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC"
    # Same alternation plus BELONGS_TO, used for path search.
    _REL_PATH = _REL_ANY + "|BELONGS_TO"
    # Labels accepted for seeds in the multi-seed query.
    _SEED_LABELS = frozenset({"Entity", "Concept", "Topic"})

    def __init__(self, driver, database: str = "neo4j"):
        """Store the driver and the name of the database to open sessions on."""
        self.driver = driver
        self.database = database

    @staticmethod
    def _quote_label(label: str) -> str:
        """Return ``label`` as a backtick-quoted Cypher identifier.

        Labels cannot be passed as query parameters, so they are interpolated
        into the query text. Embedded backticks (including the ``\\u0060``
        escape spelling) are doubled so the value cannot terminate the quoted
        identifier and inject Cypher.

        Raises:
            ValueError: if ``label`` is not a non-empty string.
        """
        if not isinstance(label, str) or not label:
            raise ValueError("Node label must be a non-empty string")
        escaped = label.replace("\\u0060", "`").replace("`", "``")
        return f"`{escaped}`"

    # ========== Co-occurrence ==========
    async def co_occurrence(
        self,
        seed_name: str,
        seed_label: str,
        other_names: Optional[List[str]] = None,
        limit: int = 500,
        top_k: int = 20,
    ) -> List[str]:
        """Find chunks in which the seed element co-occurs with other elements.

        Args:
            seed_name: ``name`` property of the seed node.
            seed_label: label of the seed node (interpolated, backtick-quoted).
            other_names: names the chunk must also link to. When empty or
                ``None``, the ``top_k`` names that most often share a chunk
                with the seed are selected by the query itself.
            limit: maximum number of chunk ids returned.
            top_k: how many co-occurring names to keep when ``other_names`` is
                not given (defaults to the previously hard-coded 20).

        Returns:
            Distinct ``milvus_id`` values of the matching chunks, as strings.
        """
        label = self._quote_label(seed_label)
        rel = self._REL_ANY
        params: Dict[str, Any] = {"seed_name": seed_name, "limit": limit}

        if other_names:
            # Names supplied by the caller: expand directly.
            cypher = f"""
            MATCH (seed:{label} {{name:$seed_name}})
                  <-[{rel}]-(gc:GraphChunk)
            MATCH (gc)-[{rel}]->(o)
            WHERE o.name IN $other_names
            RETURN DISTINCT gc.milvus_id AS milvus_id
            LIMIT $limit
            """
            params["other_names"] = other_names
        else:
            # Single round trip: a scoped subquery ranks the co-occurring names,
            # then the outer query collects chunks linking the seed to them.
            cypher = f"""
            MATCH (seed:{label} {{name:$seed_name}})
            CALL (seed) {{
              MATCH (seed)<-[{rel}]-(gc:GraphChunk)
              MATCH (gc)-[{rel}]->(other)
              WHERE other <> seed
              WITH other.name AS name, count(*) AS co_freq
              ORDER BY co_freq DESC
              RETURN collect(name)[..$top_k] AS sel_names
            }}
            WITH seed, sel_names
            WHERE size(sel_names) > 0
            MATCH (seed)<-[{rel}]-(gc2:GraphChunk)
            MATCH (gc2)-[{rel}]->(o2)
            WHERE o2.name IN sel_names
            RETURN DISTINCT gc2.milvus_id AS milvus_id
            LIMIT $limit
            """
            params["top_k"] = top_k

        async with self.driver.session(database=self.database) as session:
            records = await session.run(cypher, params)
            return [str(r["milvus_id"]) async for r in records]

    # ========== Multi-seed co-occurrence (each seed carries its own label) ==========
    async def co_occurrence_multi_mixed_labels(
        self,
        seeds: List[
            Dict[str, str]
        ],  # [{"name": "...", "label": "Entity|Concept|Topic"}, ...]
        other_names: Optional[List[str]] = None,
        limit: int = 500,
        top_k: int = 20,
        min_support: int = 1,  # minimum number of distinct seeds to co-occur with
    ) -> List[str]:
        """Find chunks that link any of several seeds to co-occurring elements.

        Args:
            seeds: dicts with ``"name"`` and ``"label"``; the label must be one
                of ``Entity``, ``Concept`` or ``Topic``.
            other_names: names the chunk must also link to. When empty or
                ``None``, a first query picks the ``top_k`` names ranked by how
                many distinct seeds they co-occur with, then by chunk count.
            limit: maximum number of chunk ids returned.
            top_k: number of names kept by the ranking query.
            min_support: minimum count of distinct seed names an element must
                share a chunk with to be kept by the ranking query.

        Returns:
            Distinct ``milvus_id`` values as strings; ``[]`` when ``seeds`` is
            empty or the ranking query yields no names.

        Raises:
            ValueError: if any seed label is outside the allowed set.
        """
        if not seeds:
            return []

        # 1) Validate labels before touching anything else, then bucket the
        #    de-duplicated names per label so each Cypher branch only sees the
        #    names relevant to its label.
        if not {s.get("label") for s in seeds} <= self._SEED_LABELS:
            raise ValueError("Seed label must be one of Entity/Concept/Topic")

        buckets: Dict[str, set] = {label: set() for label in self._SEED_LABELS}
        for seed in seeds:
            buckets[seed["label"]].add(seed["name"])
        seed_params: Dict[str, Any] = {
            "seed_entities": sorted(buckets["Entity"]),
            "seed_concepts": sorted(buckets["Concept"]),
            "seed_topics": sorted(buckets["Topic"]),
        }

        async with self.driver.session(database=self.database) as session:
            # 2) Rank co-occurring names across all seeds (only when the caller
            #    did not supply other_names). The relationship alternation is
            #    written in the same non-deprecated form as _REL_ANY.
            if not other_names:
                query_top = """
                // ========== 先收敛到与 seeds 共现的 gc + sname ==========
                CALL {
                  WITH $seed_entities AS se
                  UNWIND se AS sname
                  MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
                  RETURN gc, sname
                  UNION
                  WITH $seed_concepts AS sc
                  UNWIND sc AS sname
                  MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
                  RETURN gc, sname
                  UNION
                  WITH $seed_topics AS st
                  UNWIND st AS sname
                  MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
                  RETURN gc, sname
                }

                // ========== 在这些 gc 上一次性发散到 other ==========
                MATCH (gc)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]->(other)
                // 排除把种子本身当作 other 的情况(基于标签分桶,仍然命中索引)
                WHERE NOT (
                  (other:Entity  AND other.name IN $seed_entities) OR
                  (other:Concept AND other.name IN $seed_concepts) OR
                  (other:Topic   AND other.name IN $seed_topics)
                )
                WITH other.name AS name, gc, sname
                // 跨 gc 汇总频次;跨 sname 汇总支持度(命中的不同 seed 数)
                WITH name,
                     count(DISTINCT gc) AS co_freq,
                     count(DISTINCT sname) AS seed_support
                WHERE seed_support >= $min_support
                ORDER BY seed_support DESC, co_freq DESC
                LIMIT $top_k
                RETURN name
                """

                params_top = {
                    **seed_params,
                    "top_k": top_k,
                    "min_support": min_support,
                }
                recs = await session.run(query_top, params_top)
                other_names = [r["name"] async for r in recs]
                if not other_names:
                    return []

            # 3) Retrieval: first fix the chunks connected to any seed, then
            #    look each other-name up per label via UNWIND + equality MATCH
            #    (the author's intent is to let the planner use a name index
            #    rather than evaluate `o.name IN $list`).
            query_expand = """
            // 先生成候选 gc(包含任一 seed)
            CALL {
              WITH $seed_entities AS se
              UNWIND se AS sname
              MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
              RETURN DISTINCT gc
              UNION
              WITH $seed_concepts AS sc
              UNWIND sc AS sname
              MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
              RETURN DISTINCT gc
              UNION
              WITH $seed_topics AS st
              UNWIND st AS sname
              MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
              RETURN DISTINCT gc
            }

            // 用索引命中 other(UNWIND 列表,让等值匹配可走索引)
            UNWIND $other_names AS oname
            CALL {
              WITH oname
              MATCH (o:Entity {name:oname}) RETURN o
              UNION
              WITH oname
              MATCH (o:Concept {name:oname}) RETURN o
              UNION
              WITH oname
              MATCH (o:Topic {name:oname}) RETURN o
            }

            MATCH (gc)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]->(o)
            RETURN DISTINCT gc.milvus_id AS milvus_id
            LIMIT $limit
            """

            params_expand = {
                **seed_params,
                "other_names": other_names,
                "limit": limit,
            }
            recs2 = await session.run(query_expand, params_expand)
            # Stringify like the sibling methods so the List[str] contract holds.
            return [str(r["milvus_id"]) async for r in recs2]

    # ========== Path ==========
    async def shortest_path_chunks(
        self,
        a_name: str,
        a_label: str,
        b_name: str,
        b_label: str,
        max_len: int = 4,
        limit: int = 200,
    ) -> List[str]:
        """Return the chunks lying on one shortest path between two elements.

        Args:
            a_name, a_label: ``name`` property and label of the first node.
            b_name, b_label: ``name`` property and label of the second node.
            max_len: upper bound on the number of relationships in the path.
            limit: maximum number of chunk ids returned.

        Returns:
            Distinct ``milvus_id`` values (as strings) of the ``GraphChunk``
            nodes on the path found.

        Raises:
            ValueError: if ``max_len`` is below 1 or a label is empty.
        """
        # A variable-length bound cannot be a query parameter, so it has to be
        # interpolated; coercing to int keeps arbitrary text out of the query.
        hops = int(max_len)
        if hops < 1:
            raise ValueError("max_len must be a positive integer")

        label_a = self._quote_label(a_label)
        label_b = self._quote_label(b_label)
        query = f"""
        MATCH (a:{label_a} {{name:$a_name}}), (b:{label_b} {{name:$b_name}})
        CALL {{
          WITH a,b
          MATCH p = shortestPath(
            (a)-[{self._REL_PATH}*..{hops}]-(b)
          )
          RETURN p LIMIT 1
        }}
        WITH p
        UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
        RETURN DISTINCT gc.milvus_id AS milvus_id
        LIMIT $limit
        """
        params = {"a_name": a_name, "b_name": b_name, "limit": limit}
        async with self.driver.session(database=self.database) as session:
            records = await session.run(query, params)
            return [str(r["milvus_id"]) async for r in records]

    # ========== Expansion ==========
    async def expand_candidates(
        self, seed_ids: List[int], k_per_relation: int = 200, limit: int = 1000
    ) -> List[Dict[str, Any]]:
        """Expand candidate chunks one hop through shared elements and score them.

        For each relation kind, chunks sharing an Entity / Concept / Topic with
        any seed chunk are collected (at most ``k_per_relation`` per kind, with
        no ordering applied before that cut) and weighted 1.0 / 0.7 / 0.6
        respectively; a chunk's score is the sum of the weights of the kinds
        through which it was reached.

        Args:
            seed_ids: ``milvus_id`` values of the seed chunks; seeds themselves
                are excluded from the result.
            k_per_relation: per-relation-kind cap on candidates.
            limit: maximum number of rows returned.

        Returns:
            Dicts with ``milvus_id``, ``chunk_id``, ``doc_id`` and ``score``,
            ordered by descending score; ``[]`` when ``seed_ids`` is empty.
        """
        if not seed_ids:
            return []

        # Each UNION branch is an independent query part, so every branch reads
        # the parameters directly instead of relying on variables bound in
        # another branch.
        query = """
        CALL {
          // 同实体
          MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
          MATCH (gc0)-[:HAS_ENTITY]->(:Entity)<-[:HAS_ENTITY]-(cand:GraphChunk)
          WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
          WITH DISTINCT cand, 1.0 AS w
          LIMIT $k_per_relation
          RETURN cand AS gc, w

          UNION

          // 同概念
          MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
          MATCH (gc0)-[:HAS_CONCEPT]->(:Concept)<-[:HAS_CONCEPT]-(cand:GraphChunk)
          WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
          WITH DISTINCT cand, 0.7 AS w
          LIMIT $k_per_relation
          RETURN cand AS gc, w

          UNION

          // 同主题
          MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
          MATCH (gc0)-[:HAS_TOPIC]->(:Topic)<-[:HAS_TOPIC]-(cand:GraphChunk)
          WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
          WITH DISTINCT cand, 0.6 AS w
          LIMIT $k_per_relation
          RETURN cand AS gc, w
        }
        WITH gc, sum(w) AS score
        RETURN
          gc.milvus_id AS milvus_id,
          gc.chunk_id AS chunk_id,   // 若你的属性名不同,请改成实际字段
          gc.doc_id   AS doc_id,     // 同上
          score
        ORDER BY score DESC
        LIMIT $limit
        """
        params = {
            "seed_ids": seed_ids,
            "k_per_relation": k_per_relation,
            "limit": limit,
        }
        async with self.driver.session(database=self.database) as session:
            records = await session.run(query, params)
            return [
                {
                    "milvus_id": r["milvus_id"],
                    "chunk_id": r["chunk_id"],
                    "doc_id": r["doc_id"],
                    "score": r["score"],
                }
                async for r in records
            ]
|