|
@@ -1,4 +1,5 @@
|
|
-from typing import List, Dict, Any, Optional
|
|
|
|
|
|
+import json
|
|
|
|
+from typing import List, Optional, Dict, Any
|
|
|
|
|
|
|
|
|
|
class AsyncGraphExpansion:
|
|
class AsyncGraphExpansion:
|
|
@@ -35,6 +36,7 @@ class AsyncGraphExpansion:
|
|
"""
|
|
"""
|
|
records = await session.run(query_top, {"seed_name": seed_name})
|
|
records = await session.run(query_top, {"seed_name": seed_name})
|
|
other_names = [r["name"] async for r in records]
|
|
other_names = [r["name"] async for r in records]
|
|
|
|
+ print(other_names)
|
|
|
|
|
|
# 根据共现要素回捞 chunk
|
|
# 根据共现要素回捞 chunk
|
|
query_expand = f"""
|
|
query_expand = f"""
|
|
@@ -49,7 +51,134 @@ class AsyncGraphExpansion:
|
|
query_expand,
|
|
query_expand,
|
|
{"seed_name": seed_name, "other_names": other_names, "limit": limit},
|
|
{"seed_name": seed_name, "other_names": other_names, "limit": limit},
|
|
)
|
|
)
|
|
- return [r["milvus_id"] async for r in records]
|
|
|
|
|
|
+ return [str(r["milvus_id"]) async for r in records]
|
|
|
|
+
|
|
|
|
+ # ========== 多种子(每个 seed 自带标签)共现(性能优化版) ==========
|
|
|
|
+ async def co_occurrence_multi_mixed_labels(
|
|
|
|
+ self,
|
|
|
|
+ seeds: List[
|
|
|
|
+ Dict[str, str]
|
|
|
|
+ ], # [{"name": "...", "label": "Entity|Concept|Topic"}, ...]
|
|
|
|
+ other_names: Optional[List[str]] = None,
|
|
|
|
+ limit: int = 500,
|
|
|
|
+ top_k: int = 20,
|
|
|
|
+ min_support: int = 1, # 至少与多少个不同 seed 共现
|
|
|
|
+ ) -> List[str]:
|
|
|
|
+ if not seeds:
|
|
|
|
+ return []
|
|
|
|
+
|
|
|
|
+ # 1) 标签白名单 + 去重分桶(减少 Cypher 分支中过滤与扫描)
|
|
|
|
+ allowed = {"Entity", "Concept", "Topic"}
|
|
|
|
+ seed_entities = sorted({s["name"] for s in seeds if s.get("label") == "Entity"})
|
|
|
|
+ seed_concepts = sorted(
|
|
|
|
+ {s["name"] for s in seeds if s.get("label") == "Concept"}
|
|
|
|
+ )
|
|
|
|
+ seed_topics = sorted({s["name"] for s in seeds if s.get("label") == "Topic"})
|
|
|
|
+ if not (set(x.get("label") for x in seeds) <= allowed):
|
|
|
|
+ raise ValueError("Seed label must be one of Entity/Concept/Topic")
|
|
|
|
+
|
|
|
|
+ async with self.driver.session(database=self.database) as session:
|
|
|
|
+ # 2) 统计跨 seeds 的共现 TopN(当未显式提供 other_names 时)
|
|
|
|
+ if not other_names:
|
|
|
|
+ query_top = """
|
|
|
|
+ // ========== 先收敛到与 seeds 共现的 gc + sname ==========
|
|
|
|
+ CALL {
|
|
|
|
+ WITH $seed_entities AS se
|
|
|
|
+ UNWIND se AS sname
|
|
|
|
+ MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
|
|
|
|
+ RETURN gc, sname
|
|
|
|
+ UNION
|
|
|
|
+ WITH $seed_concepts AS sc
|
|
|
|
+ UNWIND sc AS sname
|
|
|
|
+ MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
|
|
|
|
+ RETURN gc, sname
|
|
|
|
+ UNION
|
|
|
|
+ WITH $seed_topics AS st
|
|
|
|
+ UNWIND st AS sname
|
|
|
|
+ MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
|
|
|
|
+ RETURN gc, sname
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // ========== 在这些 gc 上一次性发散到 other ==========
|
|
|
|
+ MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(other)
|
|
|
|
+ // 排除把种子本身当作 other 的情况(基于标签分桶,仍然命中索引)
|
|
|
|
+ WHERE NOT (
|
|
|
|
+ (other:Entity AND other.name IN $seed_entities) OR
|
|
|
|
+ (other:Concept AND other.name IN $seed_concepts) OR
|
|
|
|
+ (other:Topic AND other.name IN $seed_topics)
|
|
|
|
+ )
|
|
|
|
+ WITH other.name AS name, gc, sname
|
|
|
|
+ // 跨 gc 汇总频次;跨 sname 汇总支持度(命中的不同 seed 数)
|
|
|
|
+ WITH name,
|
|
|
|
+ count(DISTINCT gc) AS co_freq,
|
|
|
|
+ count(DISTINCT sname) AS seed_support
|
|
|
|
+ WHERE seed_support >= $min_support
|
|
|
|
+ ORDER BY seed_support DESC, co_freq DESC
|
|
|
|
+ LIMIT $top_k
|
|
|
|
+ RETURN name
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ params_top = {
|
|
|
|
+ "seed_entities": seed_entities,
|
|
|
|
+ "seed_concepts": seed_concepts,
|
|
|
|
+ "seed_topics": seed_topics,
|
|
|
|
+ "top_k": top_k,
|
|
|
|
+ "min_support": min_support,
|
|
|
|
+ }
|
|
|
|
+ recs = await session.run(query_top, params_top)
|
|
|
|
+ other_names = [r["name"] async for r in recs]
|
|
|
|
+ if not other_names:
|
|
|
|
+ return []
|
|
|
|
+
|
|
|
|
+ # 3) 回捞阶段:先确定“与 seeds 相连的 gc”,再用 other_names 索引命中
|
|
|
|
+ # 这里用 UNWIND onames + 按标签逐类 MATCH,确保使用 name 索引,
|
|
|
|
+ # 避免 `o.name IN $list` 走全表扫描/低效 plan。
|
|
|
|
+ query_expand = """
|
|
|
|
+ // 先生成候选 gc(包含任一 seed)
|
|
|
|
+ CALL {
|
|
|
|
+ WITH $seed_entities AS se
|
|
|
|
+ UNWIND se AS sname
|
|
|
|
+ MATCH (:Entity {name:sname})<-[:HAS_ENTITY]-(gc:GraphChunk)
|
|
|
|
+ RETURN DISTINCT gc
|
|
|
|
+ UNION
|
|
|
|
+ WITH $seed_concepts AS sc
|
|
|
|
+ UNWIND sc AS sname
|
|
|
|
+ MATCH (:Concept {name:sname})<-[:HAS_CONCEPT]-(gc:GraphChunk)
|
|
|
|
+ RETURN DISTINCT gc
|
|
|
|
+ UNION
|
|
|
|
+ WITH $seed_topics AS st
|
|
|
|
+ UNWIND st AS sname
|
|
|
|
+ MATCH (:Topic {name:sname})<-[:HAS_TOPIC]-(gc:GraphChunk)
|
|
|
|
+ RETURN DISTINCT gc
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // 用索引命中 other(UNWIND 列表,让等值匹配可走索引)
|
|
|
|
+ UNWIND $other_names AS oname
|
|
|
|
+ CALL {
|
|
|
|
+ WITH oname
|
|
|
|
+ MATCH (o:Entity {name:oname}) RETURN o
|
|
|
|
+ UNION
|
|
|
|
+ WITH oname
|
|
|
|
+ MATCH (o:Concept {name:oname}) RETURN o
|
|
|
|
+ UNION
|
|
|
|
+ WITH oname
|
|
|
|
+ MATCH (o:Topic {name:oname}) RETURN o
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(o)
|
|
|
|
+ RETURN DISTINCT gc.milvus_id AS milvus_id
|
|
|
|
+ LIMIT $limit
|
|
|
|
+ """
|
|
|
|
+
|
|
|
|
+ params_expand = {
|
|
|
|
+ "seed_entities": seed_entities,
|
|
|
|
+ "seed_concepts": seed_concepts,
|
|
|
|
+ "seed_topics": seed_topics,
|
|
|
|
+ "other_names": other_names,
|
|
|
|
+ "limit": limit,
|
|
|
|
+ }
|
|
|
|
+ recs2 = await session.run(query_expand, params_expand)
|
|
|
|
+ return [r["milvus_id"] async for r in recs2]
|
|
|
|
|
|
# ========== 路径 Path ==========
|
|
# ========== 路径 Path ==========
|
|
async def shortest_path_chunks(
|
|
async def shortest_path_chunks(
|
|
@@ -76,44 +205,59 @@ class AsyncGraphExpansion:
|
|
WITH p
|
|
WITH p
|
|
UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
|
|
UNWIND [n IN nodes(p) WHERE n:GraphChunk | n] AS gc
|
|
RETURN DISTINCT gc.milvus_id AS milvus_id
|
|
RETURN DISTINCT gc.milvus_id AS milvus_id
|
|
- LIMIT $limit
|
|
|
|
|
|
+ LIMIT ${limit}
|
|
"""
|
|
"""
|
|
async with self.driver.session(database=self.database) as session:
|
|
async with self.driver.session(database=self.database) as session:
|
|
records = await session.run(query, {"a_name": a_name, "b_name": b_name})
|
|
records = await session.run(query, {"a_name": a_name, "b_name": b_name})
|
|
- return [r["milvus_id"] async for r in records]
|
|
|
|
|
|
+ return [str(r["milvus_id"]) async for r in records]
|
|
|
|
|
|
# ========== 扩展 Expansion ==========
|
|
# ========== 扩展 Expansion ==========
|
|
async def expand_candidates(
|
|
async def expand_candidates(
|
|
- self, seed_ids: List[str], k_per_relation: int = 200, limit: int = 1000
|
|
|
|
- ) -> List[str]:
|
|
|
|
|
|
+ self, seed_ids: List[int], k_per_relation: int = 200, limit: int = 1000
|
|
|
|
+ ) -> List[Dict[str, Any]]:
|
|
"""
|
|
"""
|
|
基于候选 milvus_id 做 1-hop 扩展(实体/概念/主题),并按权重汇总
|
|
基于候选 milvus_id 做 1-hop 扩展(实体/概念/主题),并按权重汇总
|
|
"""
|
|
"""
|
|
query = """
|
|
query = """
|
|
- MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
|
|
|
|
- // 同实体
|
|
|
|
- MATCH (gc)-[:HAS_ENTITY]->(e)<-[:HAS_ENTITY]-(gc2:GraphChunk)
|
|
|
|
- WHERE gc2 <> gc
|
|
|
|
- WITH DISTINCT gc2, 1.0 AS w
|
|
|
|
- LIMIT $k_per_relation
|
|
|
|
-
|
|
|
|
- UNION
|
|
|
|
-
|
|
|
|
- MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
|
|
|
|
- MATCH (gc)-[:HAS_CONCEPT]->(c)<-[:HAS_CONCEPT]-(gc3:GraphChunk)
|
|
|
|
- WHERE gc3 <> gc
|
|
|
|
- WITH DISTINCT gc3, 0.7 AS w
|
|
|
|
- LIMIT $k_per_relation
|
|
|
|
-
|
|
|
|
- UNION
|
|
|
|
-
|
|
|
|
- MATCH (gc:GraphChunk) WHERE gc.milvus_id IN $seed_ids
|
|
|
|
- MATCH (gc)-[:HAS_TOPIC]->(t)<-[:HAS_TOPIC]-(gc4:GraphChunk)
|
|
|
|
- WHERE gc4 <> gc
|
|
|
|
- WITH DISTINCT gc4, 0.6 AS w
|
|
|
|
- LIMIT $k_per_relation
|
|
|
|
-
|
|
|
|
- RETURN DISTINCT gc2.milvus_id AS milvus_id, sum(w) AS score
|
|
|
|
|
|
+ CALL {
|
|
|
|
+ WITH $seed_ids AS seed_ids, $k_per_relation AS k
|
|
|
|
+
|
|
|
|
+ // 同实体
|
|
|
|
+ MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN seed_ids
|
|
|
|
+ MATCH (gc0)-[:HAS_ENTITY]->(:Entity)<-[:HAS_ENTITY]-(cand:GraphChunk)
|
|
|
|
+ WHERE cand <> gc0 AND NOT cand.milvus_id IN seed_ids
|
|
|
|
+ WITH DISTINCT cand, 1.0 AS w
|
|
|
|
+ LIMIT k
|
|
|
|
+ RETURN cand AS gc, w
|
|
|
|
+
|
|
|
|
+ UNION
|
|
|
|
+
|
|
|
|
+ // 同概念
|
|
|
|
+ WITH seed_ids, k
|
|
|
|
+ MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN seed_ids
|
|
|
|
+ MATCH (gc0)-[:HAS_CONCEPT]->(:Concept)<-[:HAS_CONCEPT]-(cand:GraphChunk)
|
|
|
|
+ WHERE cand <> gc0 AND NOT cand.milvus_id IN seed_ids
|
|
|
|
+ WITH DISTINCT cand, 0.7 AS w
|
|
|
|
+ LIMIT k
|
|
|
|
+ RETURN cand AS gc, w
|
|
|
|
+
|
|
|
|
+ UNION
|
|
|
|
+
|
|
|
|
+ // 同主题
|
|
|
|
+ WITH seed_ids, k
|
|
|
|
+ MATCH (gc0:GraphChunk) WHERE gc0.milvus_id IN seed_ids
|
|
|
|
+ MATCH (gc0)-[:HAS_TOPIC]->(:Topic)<-[:HAS_TOPIC]-(cand:GraphChunk)
|
|
|
|
+ WHERE cand <> gc0 AND NOT cand.milvus_id IN seed_ids
|
|
|
|
+ WITH DISTINCT cand, 0.6 AS w
|
|
|
|
+ LIMIT k
|
|
|
|
+ RETURN cand AS gc, w
|
|
|
|
+ }
|
|
|
|
+ WITH gc, sum(w) AS score
|
|
|
|
+ RETURN
|
|
|
|
+ gc.milvus_id AS milvus_id,
|
|
|
|
+ gc.chunk_id AS chunk_id, // 若你的属性名不同,请改成实际字段
|
|
|
|
+ gc.doc_id AS doc_id, // 同上
|
|
|
|
+ score
|
|
ORDER BY score DESC
|
|
ORDER BY score DESC
|
|
LIMIT $limit
|
|
LIMIT $limit
|
|
"""
|
|
"""
|
|
@@ -126,4 +270,12 @@ class AsyncGraphExpansion:
|
|
"limit": limit,
|
|
"limit": limit,
|
|
},
|
|
},
|
|
)
|
|
)
|
|
- return [r["milvus_id"] async for r in records]
|
|
|
|
|
|
+ return [
|
|
|
|
+ {
|
|
|
|
+ "milvus_id": r["milvus_id"],
|
|
|
|
+ "chunk_id": r["chunk_id"],
|
|
|
|
+ "doc_id": r["doc_id"],
|
|
|
|
+ "score": r["score"],
|
|
|
|
+ }
|
|
|
|
+ async for r in records
|
|
|
|
+ ]
|