import json
from typing import List, Optional, Dict, Any, Iterable


class AsyncGraphExpansion:
    """Graph-side expansion helpers that run Cypher through an async driver.

    Every public method opens its own session on ``self.database`` and
    returns plain Python values (chunk ids or dict rows).
    """

    # Non-deprecated alternation syntax: only the first relationship type
    # carries a leading colon.
    _REL_ANY = ":HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC"

    # Labels accepted for multi-seed lookups.
    _SEED_LABELS = ("Entity", "Concept", "Topic")

    def __init__(self, driver, database: str = "neo4j"):
        """Store the driver and the database name used for every session."""
        self.driver = driver
        self.database = database

    @staticmethod
    def _quote_label(label: str) -> str:
        """Return *label* as a backtick-quoted Cypher identifier.

        Embedded backticks are doubled so a caller-supplied label cannot
        terminate the quoted identifier and inject additional Cypher.
        """
        return "`" + str(label).replace("`", "``") + "`"

    # ========== Co-occurrence ==========
    async def co_occurrence(
        self,
        seed_name: str,
        seed_label: str,
        other_names: Optional[List[str]] = None,
        limit: int = 500,
        top_k: int = 20,
    ) -> List[str]:
        """Find chunks in which the seed element co-occurs with other elements.

        Args:
            seed_name: ``name`` property of the seed node.
            seed_label: Label of the seed node (interpolated as a quoted
                identifier).
            other_names: Names the seed must co-occur with. When empty or
                ``None``, the ``top_k`` most frequently co-occurring names
                are selected by the query itself.
            limit: Maximum number of chunk ids returned.
            top_k: How many co-occurring names to keep when ``other_names``
                is not supplied.

        Returns:
            Distinct ``milvus_id`` values of matching chunks, as strings.
        """
        label = self._quote_label(seed_label)
        rel = self._REL_ANY

        if not other_names:
            # Single round trip: a scoped subquery picks the most frequent
            # co-occurring names, then the outer query collects the chunks.
            cypher = f"""
            MATCH (seed:{label} {{name:$seed_name}})
            CALL (seed) {{
              MATCH (seed)<-[{rel}]-(gc:GraphChunk)
              MATCH (gc)-[{rel}]->(other)
              WHERE other <> seed
              WITH other.name AS name, count(*) AS co_freq
              ORDER BY co_freq DESC
              RETURN collect(name)[..$top_k] AS sel_names
            }}
            WITH seed, sel_names
            WHERE size(sel_names) > 0
            MATCH (seed)<-[{rel}]-(gc2:GraphChunk)
            MATCH (gc2)-[{rel}]->(o2)
            WHERE o2.name IN sel_names
            RETURN DISTINCT gc2.milvus_id AS milvus_id
            LIMIT $limit
            """
            params = {"seed_name": seed_name, "limit": limit, "top_k": top_k}
        else:
            # Caller supplied the names: expand directly.
            cypher = f"""
            MATCH (seed:{label} {{name:$seed_name}})
                  <-[{rel}]-(gc:GraphChunk)
            MATCH (gc)-[{rel}]->(o)
            WHERE o.name IN $other_names
            RETURN DISTINCT gc.milvus_id AS milvus_id
            LIMIT $limit
            """
            params = {
                "seed_name": seed_name,
                "other_names": other_names,
                "limit": limit,
            }

        async with self.driver.session(database=self.database) as session:
            records = await session.run(cypher, params)
            return [str(r["milvus_id"]) async for r in records]

    # ========== Multi-seed co-occurrence (each seed carries its label) ==========
    async def co_occurrence_multi_mixed_labels(
        self,
        seeds: List[
            Dict[str, str]
        ],  # [{"name": "...", "label": "Entity|Concept|Topic"}, ...]
        other_names: Optional[List[str]] = None,
        limit: int = 500,
        top_k: int = 20,
        min_support: int = 1,  # minimum number of distinct seeds co-occurring
    ) -> List[str]:
        """Find chunks where any of several labelled seeds co-occur with others.

        Args:
            seeds: Dicts with ``"name"`` and ``"label"`` keys; the label must
                be one of Entity/Concept/Topic.
            other_names: Names to co-occur with. When empty or ``None`` the
                ``top_k`` names are computed from the graph first.
            limit: Maximum number of chunk ids returned.
            top_k: Number of co-occurring names kept when they are computed.
            min_support: A computed name is kept only if it co-occurs with at
                least this many distinct seed names.

        Returns:
            Distinct ``milvus_id`` values as strings; empty when there are no
            seeds or no co-occurring names.

        Raises:
            ValueError: If a seed's label is not Entity/Concept/Topic.
        """
        if not seeds:
            return []

        # 1) Validate labels first, then bucket de-duplicated names by label
        #    so every Cypher branch only sees the names relevant to it.
        buckets: Dict[str, set] = {label: set() for label in self._SEED_LABELS}
        for seed in seeds:
            label = seed.get("label")
            if label not in buckets:
                raise ValueError("Seed label must be one of Entity/Concept/Topic")
        for seed in seeds:
            buckets[seed["label"]].add(seed["name"])

        seed_entities = sorted(buckets["Entity"])
        seed_concepts = sorted(buckets["Concept"])
        seed_topics = sorted(buckets["Topic"])

        async with self.driver.session(database=self.database) as session:
            # 2) When no names were supplied, rank co-occurring names across
            #    all seeds and keep the top ones.
            if not other_names:
                # In a WITH clause, WHERE must follow ORDER BY/LIMIT, so the
                # support filter and the ordering/limit use separate WITHs.
                query_top = """
                // ========== 先收敛到与 seeds 共现的 gc + sname ==========
                CALL {
                  WITH $seed_entities AS se
                  UNWIND se AS sname
                  MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
                  RETURN gc, sname
                  UNION
                  WITH $seed_concepts AS sc
                  UNWIND sc AS sname
                  MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
                  RETURN gc, sname
                  UNION
                  WITH $seed_topics AS st
                  UNWIND st AS sname
                  MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
                  RETURN gc, sname
                }
                // ========== 在这些 gc 上一次性发散到 other ==========
                MATCH (gc)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]->(other)
                // 排除把种子本身当作 other 的情况(基于标签分桶,仍然命中索引)
                WHERE NOT (
                  (other:Entity AND other.name IN $seed_entities) OR
                  (other:Concept AND other.name IN $seed_concepts) OR
                  (other:Topic AND other.name IN $seed_topics)
                )
                WITH other.name AS name, gc, sname
                // 跨 gc 汇总频次;跨 sname 汇总支持度(命中的不同 seed 数)
                WITH name,
                     count(DISTINCT gc) AS co_freq,
                     count(DISTINCT sname) AS seed_support
                WHERE seed_support >= $min_support
                WITH name, co_freq, seed_support
                ORDER BY seed_support DESC, co_freq DESC
                LIMIT $top_k
                RETURN name
                """
                params_top = {
                    "seed_entities": seed_entities,
                    "seed_concepts": seed_concepts,
                    "seed_topics": seed_topics,
                    "top_k": top_k,
                    "min_support": min_support,
                }
                recs = await session.run(query_top, params_top)
                other_names = [r["name"] async for r in recs]

            if not other_names:
                return []

            # 3) Fetch phase: first the chunks attached to any seed, then
            #    match the other nodes by name via UNWIND + per-label
            #    equality MATCH rather than `o.name IN $list`.
            query_expand = """
            // 先生成候选 gc(包含任一 seed)
            CALL {
              WITH $seed_entities AS se
              UNWIND se AS sname
              MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
              RETURN DISTINCT gc
              UNION
              WITH $seed_concepts AS sc
              UNWIND sc AS sname
              MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
              RETURN DISTINCT gc
              UNION
              WITH $seed_topics AS st
              UNWIND st AS sname
              MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
              RETURN DISTINCT gc
            }
            // 用索引命中 other(UNWIND 列表,让等值匹配可走索引)
            UNWIND $other_names AS oname
            CALL {
              WITH oname
              MATCH (o:Entity {name:oname}) RETURN o
              UNION
              WITH oname
              MATCH (o:Concept {name:oname}) RETURN o
              UNION
              WITH oname
              MATCH (o:Topic {name:oname}) RETURN o
            }
            MATCH (gc)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]->(o)
            RETURN DISTINCT gc.milvus_id AS milvus_id
            LIMIT $limit
            """
            params_expand = {
                "seed_entities": seed_entities,
                "seed_concepts": seed_concepts,
                "seed_topics": seed_topics,
                "other_names": other_names,
                "limit": limit,
            }
            recs2 = await session.run(query_expand, params_expand)
            # Stringify ids to honour the List[str] contract, as the other
            # id-returning methods of this class do.
            return [str(r["milvus_id"]) async for r in recs2]

    # ========== Path ==========
    async def shortest_path_chunks(
        self,
        a_name: str,
        a_label: str,
        b_name: str,
        b_label: str,
        max_len: int = 4,
        limit: int = 200,
    ) -> List[str]:
        """Return the chunks lying on the shortest path between two elements.

        Args:
            a_name: ``name`` of the first endpoint.
            a_label: Label of the first endpoint.
            b_name: ``name`` of the second endpoint.
            b_label: Label of the second endpoint.
            max_len: Upper bound on path length in hops; must be >= 1.
            limit: Maximum number of chunk ids returned.

        Returns:
            Distinct ``milvus_id`` values (as strings) of ``GraphChunk`` nodes
            on the path.

        Raises:
            ValueError: If ``max_len`` is smaller than 1.
        """
        # The hop bound is spliced into the pattern text, so force it to be
        # an integer before it reaches the query.
        hops = int(max_len)
        if hops < 1:
            raise ValueError("max_len must be >= 1")

        a_lbl = self._quote_label(a_label)
        b_lbl = self._quote_label(b_label)

        query = f"""
        MATCH (a:{a_lbl} {{name:$a_name}}), (b:{b_lbl} {{name:$b_name}})
        CALL {{
          WITH a,b
          MATCH p = shortestPath(
            (a)-[{self._REL_ANY}|BELONGS_TO*..{hops}]-(b)
          )
          RETURN p
          LIMIT 1
        }}
        WITH p
        UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
        RETURN DISTINCT gc.milvus_id AS milvus_id
        LIMIT $limit
        """
        params = {"a_name": a_name, "b_name": b_name, "limit": limit}

        async with self.driver.session(database=self.database) as session:
            records = await session.run(query, params)
            return [str(r["milvus_id"]) async for r in records]

    # ========== Expansion ==========
    async def expand_candidates(
        self, seed_ids: List[int], k_per_relation: int = 200, limit: int = 1000
    ) -> List[Dict[str, Any]]:
        """One-hop expansion from candidate chunks with weighted scoring.

        Candidates sharing an entity, concept or topic with a seed chunk
        contribute weights 1.0, 0.7 and 0.6 respectively; the weights of a
        candidate are summed into its score.

        Args:
            seed_ids: ``milvus_id`` values of the seed chunks.
            k_per_relation: Per-relation cap on candidates.
            limit: Maximum number of rows returned.

        Returns:
            Dicts with ``milvus_id``, ``chunk_id``, ``doc_id`` and ``score``,
            ordered by descending score.
        """
        if not seed_ids:
            # No seed chunk can match, so skip the round trip.
            return []

        # Each UNION branch reads the parameters directly: a subquery that is
        # the first clause has no outer variables to import, and LIMIT may
        # not reference variables.
        query = """
        CALL {
          // 同实体
          MATCH (gc0:GraphChunk)
          WHERE gc0.milvus_id IN $seed_ids
          MATCH (gc0)-[:HAS_ENTITY]->(:Entity)<-[:HAS_ENTITY]-(cand:GraphChunk)
          WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
          WITH DISTINCT cand, 1.0 AS w
          LIMIT $k_per_relation
          RETURN cand AS gc, w

          UNION

          // 同概念
          MATCH (gc0:GraphChunk)
          WHERE gc0.milvus_id IN $seed_ids
          MATCH (gc0)-[:HAS_CONCEPT]->(:Concept)<-[:HAS_CONCEPT]-(cand:GraphChunk)
          WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
          WITH DISTINCT cand, 0.7 AS w
          LIMIT $k_per_relation
          RETURN cand AS gc, w

          UNION

          // 同主题
          MATCH (gc0:GraphChunk)
          WHERE gc0.milvus_id IN $seed_ids
          MATCH (gc0)-[:HAS_TOPIC]->(:Topic)<-[:HAS_TOPIC]-(cand:GraphChunk)
          WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
          WITH DISTINCT cand, 0.6 AS w
          LIMIT $k_per_relation
          RETURN cand AS gc, w
        }
        WITH gc, sum(w) AS score
        RETURN gc.milvus_id AS milvus_id,
               gc.chunk_id AS chunk_id,   // 若你的属性名不同,请改成实际字段
               gc.doc_id AS doc_id,       // 同上
               score
        ORDER BY score DESC
        LIMIT $limit
        """
        params = {
            "seed_ids": seed_ids,
            "k_per_relation": k_per_relation,
            "limit": limit,
        }

        async with self.driver.session(database=self.database) as session:
            records = await session.run(query, params)
            return [
                {
                    "milvus_id": r["milvus_id"],
                    "chunk_id": r["chunk_id"],
                    "doc_id": r["doc_id"],
                    "score": r["score"],
                }
                async for r in records
            ]