123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291 |
- import json
- from typing import List, Optional, Dict, Any, Iterable
class AsyncGraphExpansion:
    """Async graph-expansion queries over ``GraphChunk`` nodes.

    Every method opens a session with ``driver.session(database=...)`` as an
    async context manager, awaits ``session.run(query, params)`` and iterates
    the result with ``async for``.  The driver is presumably a neo4j
    ``AsyncDriver`` (TODO confirm against the caller); any object with that
    shape works.

    The Cypher in this class assumes ``GraphChunk`` nodes that carry a
    ``milvus_id`` property and point at ``Entity`` / ``Concept`` / ``Topic``
    nodes (each with a ``name``) through ``HAS_ENTITY`` / ``HAS_CONCEPT`` /
    ``HAS_TOPIC`` relationships.
    """

    # Relationship-type alternation in the non-deprecated spelling: the colon
    # appears only before the first type.
    _REL_ANY = ":HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC"

    # Labels accepted for seeds in ``co_occurrence_multi_mixed_labels``.
    _SEED_LABELS = frozenset({"Entity", "Concept", "Topic"})

    def __init__(self, driver, database: str = "neo4j"):
        """Store the driver and the name of the database to query."""
        self.driver = driver
        self.database = database

    @staticmethod
    def _quote_label(label: str) -> str:
        """Return *label* wrapped in backticks for use as a Cypher label.

        Labels cannot be passed as query parameters, so they are interpolated
        into the query text.  A label that is empty or contains a backtick
        would terminate the quoted identifier early, so it is rejected.

        Raises:
            ValueError: if *label* is not a non-empty string free of backticks.
        """
        if not isinstance(label, str) or not label or "`" in label:
            raise ValueError(f"Invalid node label: {label!r}")
        return f"`{label}`"

    @staticmethod
    async def _milvus_ids(result) -> List[str]:
        """Drain *result* and return each record's ``milvus_id`` as a string."""
        return [str(record["milvus_id"]) async for record in result]

    # ========== Co-occurrence ==========
    async def co_occurrence(
        self,
        seed_name: str,
        seed_label: str,
        other_names: Optional[List[str]] = None,
        limit: int = 500,
        top_k: int = 20,
    ) -> List[str]:
        """Find chunks in which the seed element co-occurs with other elements.

        Args:
            seed_name: ``name`` property of the seed node.
            seed_label: Label of the seed node (interpolated, see
                ``_quote_label``).
            other_names: Names the seed must co-occur with.  When empty or
                ``None``, the ``top_k`` names that most frequently share a
                chunk with the seed are selected inside the same query.
            limit: Maximum number of chunk ids returned.
            top_k: Number of most frequent co-occurring names used when
                ``other_names`` is not supplied (default 20, the value that
                was previously hard-coded).

        Returns:
            Distinct ``milvus_id`` values of the matching chunks, as strings.

        Raises:
            ValueError: if ``seed_label`` is not a usable label.
        """
        label = self._quote_label(seed_label)
        rel = self._REL_ANY
        params: Dict[str, Any] = {"seed_name": seed_name, "limit": limit}

        if other_names:
            # Caller supplied the partner names: expand directly.
            cypher = f"""
            MATCH (seed:{label} {{name:$seed_name}})<-[{rel}]-(gc:GraphChunk)
            MATCH (gc)-[{rel}]->(o)
            WHERE o.name IN $other_names
            RETURN DISTINCT gc.milvus_id AS milvus_id
            LIMIT $limit
            """
            params["other_names"] = other_names
        else:
            # Single round trip: a scoped subquery ranks the co-occurring
            # names, then the outer query fetches chunks that contain the
            # seed together with any of the selected names.
            cypher = f"""
            MATCH (seed:{label} {{name:$seed_name}})
            CALL (seed) {{
                MATCH (seed)<-[{rel}]-(gc:GraphChunk)
                MATCH (gc)-[{rel}]->(other)
                WHERE other <> seed
                WITH other.name AS name, count(*) AS co_freq
                ORDER BY co_freq DESC
                RETURN collect(name)[..$top_k] AS sel_names
            }}
            WITH seed, sel_names
            WHERE size(sel_names) > 0
            MATCH (seed)<-[{rel}]-(gc2:GraphChunk)
            MATCH (gc2)-[{rel}]->(o2)
            WHERE o2.name IN sel_names
            RETURN DISTINCT gc2.milvus_id AS milvus_id
            LIMIT $limit
            """
            params["top_k"] = top_k

        async with self.driver.session(database=self.database) as session:
            result = await session.run(cypher, params)
            return await self._milvus_ids(result)

    # ========== Co-occurrence for several seeds, each with its own label ==========
    async def co_occurrence_multi_mixed_labels(
        self,
        seeds: List[Dict[str, str]],
        other_names: Optional[List[str]] = None,
        limit: int = 500,
        top_k: int = 20,
        min_support: int = 1,
    ) -> List[str]:
        """Find chunks where any of several seeds co-occurs with other elements.

        Args:
            seeds: Items shaped like
                ``{"name": "...", "label": "Entity|Concept|Topic"}``.
            other_names: Names the seeds must co-occur with.  When empty or
                ``None`` they are computed by a first query (see below).
            limit: Maximum number of chunk ids returned.
            top_k: Number of co-occurring names kept by the first query.
            min_support: Minimum number of distinct seed names a candidate
                name must share a chunk with to be kept by the first query.

        Returns:
            Distinct ``milvus_id`` values as strings; ``[]`` when ``seeds`` is
            empty or no co-occurring name is found.

        Raises:
            ValueError: if any seed label is not Entity, Concept or Topic.
            KeyError: if a seed has no ``"name"`` key.
        """
        if not seeds:
            return []

        # Validate labels before doing any other work.
        if not {seed.get("label") for seed in seeds} <= self._SEED_LABELS:
            raise ValueError("Seed label must be one of Entity/Concept/Topic")

        # Bucket the de-duplicated seed names per label so that each Cypher
        # branch only handles the names carrying its own label.
        buckets: Dict[str, set] = {label: set() for label in self._SEED_LABELS}
        for seed in seeds:
            buckets[seed["label"]].add(seed["name"])
        seed_params: Dict[str, Any] = {
            "seed_entities": sorted(buckets["Entity"]),
            "seed_concepts": sorted(buckets["Concept"]),
            "seed_topics": sorted(buckets["Topic"]),
        }

        async with self.driver.session(database=self.database) as session:
            if not other_names:
                # Step 1: rank the names that co-occur with the seeds.
                # The min_support filter stays in WITH ... WHERE; ordering
                # and LIMIT sit on RETURN because a WHERE sub-clause has to be
                # the last part of a WITH clause.
                query_top = """
                // ========== 先收敛到与 seeds 共现的 gc + sname ==========
                CALL {
                    UNWIND $seed_entities AS sname
                    MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
                    RETURN gc, sname
                  UNION
                    UNWIND $seed_concepts AS sname
                    MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
                    RETURN gc, sname
                  UNION
                    UNWIND $seed_topics AS sname
                    MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
                    RETURN gc, sname
                }
                // ========== 在这些 gc 上一次性发散到 other ==========
                MATCH (gc)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]->(other)
                // 排除把种子本身当作 other 的情况(基于标签分桶,仍然命中索引)
                WHERE NOT (
                    (other:Entity AND other.name IN $seed_entities) OR
                    (other:Concept AND other.name IN $seed_concepts) OR
                    (other:Topic AND other.name IN $seed_topics)
                )
                WITH other.name AS name, gc, sname
                // 跨 gc 汇总频次;跨 sname 汇总支持度(命中的不同 seed 数)
                WITH name,
                     count(DISTINCT gc) AS co_freq,
                     count(DISTINCT sname) AS seed_support
                WHERE seed_support >= $min_support
                RETURN name
                ORDER BY seed_support DESC, co_freq DESC
                LIMIT $top_k
                """
                top_result = await session.run(
                    query_top,
                    {**seed_params, "top_k": top_k, "min_support": min_support},
                )
                other_names = [record["name"] async for record in top_result]
                if not other_names:
                    return []

            # Step 2: collect the chunks attached to any seed, then keep those
            # that also point at one of ``other_names``.  The names are
            # unwound and matched per label with an equality predicate; the
            # original author's intent was to let a name index be used
            # (not verifiable from this file).
            query_expand = """
            // 先生成候选 gc(包含任一 seed)
            CALL {
                UNWIND $seed_entities AS sname
                MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
                RETURN DISTINCT gc
              UNION
                UNWIND $seed_concepts AS sname
                MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
                RETURN DISTINCT gc
              UNION
                UNWIND $seed_topics AS sname
                MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
                RETURN DISTINCT gc
            }
            // 用索引命中 other(UNWIND 列表,让等值匹配可走索引)
            UNWIND $other_names AS oname
            CALL {
                WITH oname
                MATCH (o:Entity {name:oname}) RETURN o
              UNION
                WITH oname
                MATCH (o:Concept {name:oname}) RETURN o
              UNION
                WITH oname
                MATCH (o:Topic {name:oname}) RETURN o
            }
            MATCH (gc)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC]->(o)
            RETURN DISTINCT gc.milvus_id AS milvus_id
            LIMIT $limit
            """
            expand_result = await session.run(
                query_expand,
                {**seed_params, "other_names": other_names, "limit": limit},
            )
            # Stringify ids so the result matches the annotation and the
            # other methods of this class.
            return await self._milvus_ids(expand_result)

    # ========== Path ==========
    async def shortest_path_chunks(
        self,
        a_name: str,
        a_label: str,
        b_name: str,
        b_label: str,
        max_len: int = 4,
        limit: int = 200,
    ) -> List[str]:
        """Return the chunks lying on one shortest path between two elements.

        Args:
            a_name: ``name`` of the first endpoint.
            a_label: Label of the first endpoint (interpolated, see
                ``_quote_label``).
            b_name: ``name`` of the second endpoint.
            b_label: Label of the second endpoint.
            max_len: Upper bound on the path length in relationships.
            limit: Maximum number of chunk ids returned.

        Returns:
            Distinct ``milvus_id`` values, as strings, of the ``GraphChunk``
            nodes on the path.

        Raises:
            ValueError: if a label is unusable or ``max_len`` is not an
                integer value.
        """
        label_a = self._quote_label(a_label)
        label_b = self._quote_label(b_label)
        # A variable-length bound cannot be a query parameter, so it is
        # interpolated; coercing to int keeps arbitrary text out of the query.
        hops = int(max_len)

        query = f"""
        MATCH (a:{label_a} {{name:$a_name}}), (b:{label_b} {{name:$b_name}})
        CALL {{
            WITH a, b
            MATCH p = shortestPath(
                (a)-[:HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC|BELONGS_TO*..{hops}]-(b)
            )
            RETURN p LIMIT 1
        }}
        WITH p
        UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
        RETURN DISTINCT gc.milvus_id AS milvus_id
        LIMIT $limit
        """
        params = {"a_name": a_name, "b_name": b_name, "limit": limit}
        async with self.driver.session(database=self.database) as session:
            result = await session.run(query, params)
            return await self._milvus_ids(result)

    # ========== Expansion ==========
    async def expand_candidates(
        self, seed_ids: List[int], k_per_relation: int = 200, limit: int = 1000
    ) -> List[Dict[str, Any]]:
        """Expand candidate chunks by one hop and rank them by summed weight.

        For each seed chunk, other chunks sharing an Entity (weight 1.0), a
        Concept (weight 0.7) or a Topic (weight 0.6) are collected, at most
        ``k_per_relation`` per relation kind.  Weights of a chunk reached
        through several relation kinds are summed into ``score``.

        Args:
            seed_ids: ``milvus_id`` values of the seed chunks.
            k_per_relation: Cap on candidates taken from each relation kind.
            limit: Maximum number of rows returned.

        Returns:
            Dicts with ``milvus_id``, ``chunk_id``, ``doc_id`` and ``score``,
            ordered by descending score; ``[]`` when ``seed_ids`` is empty.
        """
        if not seed_ids:
            # Nothing can match an empty id list; skip the round trip.
            return []

        # Each UNION branch has its own variable scope, so the parameters are
        # referenced directly in every branch rather than aliased once.
        query = """
        CALL {
            // 同实体
            MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
            MATCH (gc0)-[:HAS_ENTITY]->(:Entity)<-[:HAS_ENTITY]-(cand:GraphChunk)
            WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
            WITH DISTINCT cand, 1.0 AS w
            LIMIT $k_per_relation
            RETURN cand AS gc, w

            UNION

            // 同概念
            MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
            MATCH (gc0)-[:HAS_CONCEPT]->(:Concept)<-[:HAS_CONCEPT]-(cand:GraphChunk)
            WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
            WITH DISTINCT cand, 0.7 AS w
            LIMIT $k_per_relation
            RETURN cand AS gc, w

            UNION

            // 同主题
            MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN $seed_ids
            MATCH (gc0)-[:HAS_TOPIC]->(:Topic)<-[:HAS_TOPIC]-(cand:GraphChunk)
            WHERE cand <> gc0 AND NOT cand.milvus_id IN $seed_ids
            WITH DISTINCT cand, 0.6 AS w
            LIMIT $k_per_relation
            RETURN cand AS gc, w
        }
        WITH gc, sum(w) AS score
        RETURN
            gc.milvus_id AS milvus_id,
            gc.chunk_id  AS chunk_id,   // 若你的属性名不同,请改成实际字段
            gc.doc_id    AS doc_id,     // 同上
            score
        ORDER BY score DESC
        LIMIT $limit
        """
        params = {
            "seed_ids": seed_ids,
            "k_per_relation": k_per_relation,
            "limit": limit,
        }
        columns = ("milvus_id", "chunk_id", "doc_id", "score")
        async with self.driver.session(database=self.database) as session:
            result = await session.run(query, params)
            return [
                {column: record[column] for column in columns}
                async for record in result
            ]
|