Ver Fonte

开发 graph_expansion

luojunhui há 1 semana
pai
commit
1c354b7486
1 ficheiro alterado com 36 adições e 26 exclusões
  1. 36 26
      applications/utils/neo4j/graph_expansion.py

+ 36 - 26
applications/utils/neo4j/graph_expansion.py

@@ -1,54 +1,64 @@
 import json
-from typing import List, Optional, Dict, Any
+from typing import List, Optional, Dict, Any, Iterable
 
 
 class AsyncGraphExpansion:
+    _REL_ANY = ":HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC"
+
     def __init__(self, driver, database: str = "neo4j"):
         self.driver = driver
         self.database = database
 
     # ========== 共现 co-occurrence ==========
     async def co_occurrence(
-        self,
-        seed_name: str,
-        seed_label: str,
-        other_names: Optional[List[str]] = None,
-        limit: int = 500,
+            self,
+            seed_name: str,
+            seed_label: str,
+            other_names: Optional[List[str]] = None,
+            limit: int = 500,
     ) -> List[str]:
         """
         找与种子要素共现的 chunk
-        :param seed_name: 种子名称
-        :param seed_label: 种子标签 ('Entity', 'Concept', 'Topic')
-        :param other_names: 指定要扩展的其它要素名称列表(若为 None,先查共现TopN再扩展)
-        :param limit: 返回数量上限
         """
+        REL_ANY = ":HAS_ENTITY|HAS_CONCEPT|HAS_TOPIC"  # 非弃用:只在第一个关系类型前加冒号
+
         async with self.driver.session(database=self.database) as session:
             if not other_names:
-                # 先统计高频共现要素
-                query_top = f"""
+                # 一次往返:显式作用域子查询 + 非弃用关系“或”写法
+                cypher = f"""
                 MATCH (seed:`{seed_label}` {{name:$seed_name}})
-                      <-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]-(gc:GraphChunk)
-                MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(other)
-                WHERE other <> seed
-                RETURN other.name AS name, count(*) AS co_freq
-                ORDER BY co_freq DESC
-                LIMIT 20
+                CALL (seed) {{
+                    MATCH (seed)<-[{REL_ANY}]-(gc:GraphChunk)
+                    MATCH (gc)-[{REL_ANY}]->(other)
+                    WHERE other <> seed
+                    WITH other.name AS name, count(*) AS co_freq
+                    ORDER BY co_freq DESC
+                    RETURN collect(name)[..20] AS sel_names
+                }}
+                WITH seed, sel_names
+                WHERE size(sel_names) > 0
+                MATCH (seed)<-[{REL_ANY}]-(gc2:GraphChunk)
+                MATCH (gc2)-[{REL_ANY}]->(o2)
+                WHERE o2.name IN sel_names
+                RETURN DISTINCT gc2.milvus_id AS milvus_id
+                LIMIT $limit
                 """
-                records = await session.run(query_top, {"seed_name": seed_name})
-                other_names = [r["name"] async for r in records]
-                print(other_names)
+                records = await session.run(
+                    cypher, {"seed_name": seed_name, "limit": limit}
+                )
+                return [str(r["milvus_id"]) async for r in records]
 
-            # 根据共现要素回捞 chunk
-            query_expand = f"""
+            # 已提供 other_names:直接扩展(同样使用非弃用关系“或”写法)
+            cypher = f"""
             MATCH (seed:`{seed_label}` {{name:$seed_name}})
-                  <-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]-(gc:GraphChunk)
-            MATCH (gc)-[:HAS_ENTITY|:HAS_CONCEPT|:HAS_TOPIC]->(o)
+                  <-[{REL_ANY}]-(gc:GraphChunk)
+            MATCH (gc)-[{REL_ANY}]->(o)
             WHERE o.name IN $other_names
             RETURN DISTINCT gc.milvus_id AS milvus_id
             LIMIT $limit
             """
             records = await session.run(
-                query_expand,
+                cypher,
                 {"seed_name": seed_name, "other_names": other_names, "limit": limit},
             )
             return [str(r["milvus_id"]) async for r in records]