Browse Source

新增知识图谱模块

luojunhui 2 tuần trước cách đây
mục cha
commit
297a3e4fd0

+ 2 - 1
applications/async_task/__init__.py

@@ -1,6 +1,7 @@
 from .chunk_task import ChunkEmbeddingTask
 from .delete_task import DeleteTask
 from .auto_rechunk_task import AutoRechunkTask
+from .build_graph import BuildGraph
 
 
-__all__ = ["ChunkEmbeddingTask", "DeleteTask", "AutoRechunkTask"]
+__all__ = ["ChunkEmbeddingTask", "DeleteTask", "AutoRechunkTask", "BuildGraph"]

+ 55 - 0
applications/async_task/build_graph.py

@@ -0,0 +1,55 @@
+"""
+use neo4j to build graph
+"""
+import json
+import random
+from dataclasses import fields
+
+from applications.utils.neo4j import AsyncNeo4jRepository
+from applications.utils.neo4j import Document, GraphChunk, ChunkRelations
+from applications.utils.async_utils import run_tasks_with_asyncio_task_group
+
+
+class BuildGraph(AsyncNeo4jRepository):
+    """Build the Neo4j graph of one document from its ``content_chunks`` rows.
+
+    Chunk metadata is read through ``mysql_client`` and written to Neo4j with
+    the inherited ``add_document_with_chunk``.
+    """
+
+    # Columns of ``content_chunks`` that hold JSON text and must be decoded.
+    _JSON_LIST_COLUMNS = ("keywords", "entities", "concepts")
+
+    def __init__(self, neo4j, mysql_client):
+        super().__init__(neo4j)
+        self.mysql_client = mysql_client
+
+    @staticmethod
+    def from_dict(cls, data: dict):
+        """Instantiate dataclass ``cls`` from ``data``, ignoring unknown keys.
+
+        This is deliberately a static method: ``cls`` is the target dataclass
+        passed explicitly by the caller, not ``BuildGraph`` itself.
+        """
+        field_names = {f.name for f in fields(cls)}
+        return cls(**{k: v for k, v in data.items() if k in field_names})
+
+    @staticmethod
+    def _load_json_list(raw):
+        """Decode one JSON column value; NULL, empty text and JSON null give []."""
+        if raw is None or raw == "":
+            return []
+        if isinstance(raw, (str, bytes, bytearray)):
+            decoded = json.loads(raw)
+            return [] if decoded is None else decoded
+        # Anything else is returned untouched instead of failing in json.loads.
+        return raw
+
+    async def add_single_chunk(self, param):
+        """Write one chunk row (a dict as returned by ``get_chunk_list``) to Neo4j."""
+        # NOTE(review): milvus_id is a random 6-digit number here, yet it is the
+        # key the write query MERGEs GraphChunk nodes on. Collisions would merge
+        # unrelated chunks and re-runs create new nodes - presumably a
+        # placeholder until the real Milvus id is available; confirm.
+        param["milvus_id"] = random.randint(100000, 999999)
+        doc: Document = self.from_dict(Document, param)
+        graph_chunk: GraphChunk = self.from_dict(GraphChunk, param)
+        relations: ChunkRelations = self.from_dict(ChunkRelations, param)
+        await self.add_document_with_chunk(doc, graph_chunk, relations)
+
+    async def get_chunk_list(self, doc_id):
+        """Return the chunk rows of ``doc_id`` with their JSON columns decoded.
+
+        Only rows with ``embedding_status = 2`` and ``status = 1`` are selected.
+        """
+        query = """
+            SELECT chunk_id, doc_id, topic, domain, task_type, keywords, concepts, entities, text_type, dataset_id
+            FROM content_chunks
+            WHERE embedding_status = %s AND status = %s and doc_id = %s;
+        """
+        rows = await self.mysql_client.async_fetch(
+            query=query,
+            params=(2, 1, doc_id)
+        )
+        chunks = []
+        # ``rows or ()`` guards against a client that returns None for "no rows".
+        for row in rows or ():
+            for column in self._JSON_LIST_COLUMNS:
+                row[column] = self._load_json_list(row[column])
+            chunks.append(row)
+        return chunks
+
+    async def deal(self, doc_id):
+        """Write every eligible chunk of ``doc_id`` to the graph, one at a time."""
+        for task in await self.get_chunk_list(doc_id):
+            await self.add_single_chunk(task)
+
+
+

+ 2 - 0
applications/config/__init__.py

@@ -9,6 +9,7 @@ from .base_chunk import Chunk, ChunkerConfig
 from .elastic_search_config import ELASTIC_SEARCH_INDEX, ES_HOSTS, ES_PASSWORD
 from .milvus_config import MILVUS_CONFIG, BASE_MILVUS_SEARCH_PARAMS
 from .mysql_config import RAG_MYSQL_CONFIG
+from .neo4j_config import NEO4j_CONFIG
 from .weight_config import WEIGHT_MAP
 
 
@@ -28,4 +29,5 @@ __all__ = [
     "ES_PASSWORD",
     "ELASTIC_SEARCH_INDEX",
     "BASE_MILVUS_SEARCH_PARAMS",
+    "NEO4j_CONFIG"
 ]

+ 1 - 0
applications/config/neo4j_config.py

@@ -0,0 +1 @@
+import os
+
+# Connection settings for the Neo4j driver. Every value can be overridden with
+# an environment variable; the literals are only fallbacks that keep existing
+# deployments working unchanged.
+# NOTE(review): the fallback password is committed in plain text - rotate it and
+# drop the default once every environment sets NEO4J_PASSWORD.
+NEO4j_CONFIG = {
+    "url": os.getenv("NEO4J_URL", "bolt://192.168.100.31:7687"),
+    "user": os.getenv("NEO4J_USER", "neo4j"),
+    "password": os.getenv("NEO4J_PASSWORD", "ljh000118"),
+}

+ 9 - 0
applications/resource/resource_manager.py

@@ -1,5 +1,7 @@
 from pymilvus import connections, CollectionSchema, Collection
+from neo4j import AsyncGraphDatabase
 
+from applications.config import NEO4j_CONFIG
 from applications.utils.mysql import DatabaseManager
 from applications.utils.milvus.field import fields
 from applications.utils.elastic_search import AsyncElasticSearchClient
@@ -15,6 +17,7 @@ class ResourceManager:
         self.es_client: AsyncElasticSearchClient | None = None
         self.milvus_client: Collection | None = None
         self.mysql_client: DatabaseManager | None = None
+        self.graph_client = None
 
     async def load_milvus(self):
         connections.connect("default", **self.milvus_config)
@@ -54,6 +57,10 @@ class ResourceManager:
         await self.load_milvus()
         print("✅ Milvus loaded")
 
+        uri: str = NEO4j_CONFIG["url"]
+        auth: tuple = NEO4j_CONFIG["user"], NEO4j_CONFIG["password"]
+        self.graph_client = AsyncGraphDatabase.driver(uri=uri, auth=auth)
+
     async def shutdown(self):
         # 关闭 Elasticsearch
         if self.es_client:
@@ -69,6 +76,8 @@ class ResourceManager:
             await self.mysql_client.close_pools()
             print("Mysql closed")
 
+        await self.graph_client.close()
+
 
 _resource_manager: ResourceManager | None = None
 

+ 9 - 0
applications/utils/neo4j/__init__.py

@@ -0,0 +1,9 @@
+from .repository import AsyncNeo4jRepository
+from .models import Document, ChunkRelations, GraphChunk
+
+__all__ = [
+    "AsyncNeo4jRepository",
+    "Document",
+    "ChunkRelations",
+    "GraphChunk",
+]

+ 95 - 0
applications/utils/neo4j/models.py

@@ -0,0 +1,95 @@
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class Document:
+    """Parameters describing the ``Document`` node written by ``QUERY``."""
+
+    # Key the Document node is MERGEd on.
+    doc_id: str
+    # Stored on the Document node as ``dataset_id``.
+    dataset_id: int
+
+
+@dataclass
+class GraphChunk:
+    """Parameters describing the ``GraphChunk`` node written by ``QUERY``."""
+
+    # Key the GraphChunk node is MERGEd on.
+    milvus_id: int
+    # Written only when the node is first created (ON CREATE SET).
+    chunk_id: int
+    # The remaining fields are (re)written on every execution of QUERY.
+    doc_id: str
+    topic: str
+    domain: str
+    text_type: int
+    task_type: str
+
+
+@dataclass
+class ChunkRelations:
+    """Names of the nodes a chunk is linked to by ``QUERY``."""
+
+    # Each non-blank name becomes an Entity / Concept / Keyword node linked
+    # from the chunk (HAS_ENTITY / HAS_CONCEPT / HAS_KEYWORD).
+    entities: List[str]
+    concepts: List[str]
+    keywords: List[str]
+    # Names for the Domain / Topic nodes; QUERY skips them when null or blank.
+    domain: str
+    topic: str
+
+
+# Cypher statement that upserts one chunk and everything attached to it.
+# Parameters: $doc_id, $dataset_id, $milvus_id, $chunk_id, $topic, $domain,
+# $text_type, $task_type, $entities, $concepts, $keywords, $domain_name,
+# $topic_name.
+#
+# The name lists are written with FOREACH rather than UNWIND. UNWIND of an
+# empty list (or a WHERE that filters out every row) yields zero rows, which
+# silently skips every later section of the statement, and a non-empty list
+# repeats the later sections once per row. FOREACH never changes the row count.
+QUERY = """
+// 1) Document & GraphChunk
+MERGE (d:Document {doc_id: $doc_id})
+  SET d.dataset_id = $dataset_id
+
+MERGE (gc:GraphChunk {milvus_id: $milvus_id})
+  ON CREATE SET gc.chunk_id = $chunk_id
+  SET gc.topic     = $topic,
+      gc.domain    = $domain,
+      gc.text_type = $text_type,
+      gc.task_type = $task_type,
+      gc.doc_id    = $doc_id
+
+MERGE (gc)-[:BELONGS_TO]->(d)
+MERGE (d)-[:HAS_CHUNK]->(gc)
+
+// 2) 参数准备
+WITH gc,
+     COALESCE($entities,  []) AS entities,
+     COALESCE($concepts,  []) AS concepts,
+     COALESCE($keywords,  []) AS keywords,
+     $domain_name AS domain_name,
+     $topic_name  AS topic_name
+
+// 3) Entities
+FOREACH (e_name IN [x IN entities WHERE x IS NOT NULL AND TRIM(x) <> "" | TRIM(x)] |
+  MERGE (e:Entity {name: e_name})
+  MERGE (gc)-[:HAS_ENTITY]->(e)
+)
+
+// 4) Concepts
+FOREACH (c_name IN [x IN concepts WHERE x IS NOT NULL AND TRIM(x) <> "" | TRIM(x)] |
+  MERGE (co:Concept {name: c_name})
+  MERGE (gc)-[:HAS_CONCEPT]->(co)
+)
+
+// 5) Keywords
+FOREACH (k_name IN [x IN keywords WHERE x IS NOT NULL AND TRIM(x) <> "" | TRIM(x)] |
+  MERGE (k:Keyword {name: k_name})
+  MERGE (gc)-[:HAS_KEYWORD]->(k)
+)
+
+// 6) Domain(条件执行,用 FOREACH 替代 CALL)
+WITH gc, domain_name, topic_name
+FOREACH (_ IN CASE WHEN domain_name IS NOT NULL AND TRIM(domain_name) <> "" THEN [1] ELSE [] END |
+  MERGE (d_node:Domain {name: TRIM(domain_name)})
+  MERGE (gc)-[:HAS_DOMAIN]->(d_node)
+)
+
+// 7) Topic(条件执行,用 FOREACH 替代 CALL)
+WITH gc, topic_name
+FOREACH (_ IN CASE WHEN topic_name IS NOT NULL AND TRIM(topic_name) <> "" THEN [1] ELSE [] END |
+  MERGE (t:Topic {name: TRIM(topic_name)})
+  MERGE (gc)-[:HAS_TOPIC]->(t)
+)
+"""

+ 28 - 0
applications/utils/neo4j/query.py

@@ -0,0 +1,28 @@
+class AsyncNeo4jQuery:
+    """Read-only lookups on the document/chunk graph.
+
+    ``neo4j`` is the driver object; it must provide ``session()`` (an async
+    context manager) and ``close()``.
+    """
+
+    def __init__(self, neo4j):
+        self.neo4j = neo4j
+
+    async def close(self):
+        """Close the underlying driver."""
+        await self.neo4j.close()
+
+    async def _fetch_all(self, query: str, **params) -> list:
+        """Run ``query`` and return every record as ``record.data()``."""
+        async with self.neo4j.session() as session:
+            result = await session.run(query, **params)
+            # The result must be iterated asynchronously, inside the session.
+            return [record.data() async for record in result]
+
+    async def get_document_by_id(self, doc_id: str):
+        """Return the document ``doc_id`` together with its collected chunks."""
+        # Chunk nodes carry the label GraphChunk in the write query, so that is
+        # the label matched here.
+        query = """
+        MATCH (d:Document {doc_id: $doc_id})
+        OPTIONAL MATCH (d)-[:HAS_CHUNK]->(c:GraphChunk)
+        RETURN d, collect(c) as chunks
+        """
+        return await self._fetch_all(query, doc_id=doc_id)
+
+    async def search_chunks_by_topic(self, topic: str):
+        """Return the chunks whose ``topic`` equals ``topic`` with their entity names."""
+        query = """
+        MATCH (c:GraphChunk {topic: $topic})
+        OPTIONAL MATCH (c)-[:HAS_ENTITY]->(e:Entity)
+        RETURN c, collect(e.name) as entities
+        """
+        return await self._fetch_all(query, topic=topic)

+ 28 - 0
applications/utils/neo4j/repository.py

@@ -0,0 +1,28 @@
+from .models import Document, GraphChunk, ChunkRelations, QUERY
+
+
+class AsyncNeo4jRepository:
+    """Write access to the graph: upserts a document, a chunk and its relations.
+
+    ``neo4j`` is the driver object; it must provide ``session()`` returning an
+    async context manager.
+    """
+
+    def __init__(self, neo4j):
+        self.neo4j = neo4j
+
+    async def add_document_with_chunk(
+        self, doc: Document, chunk: GraphChunk, relations: ChunkRelations
+    ):
+        """Run ``QUERY`` once for ``chunk`` of ``doc`` with its ``relations``.
+
+        The result is consumed before the session closes, so the write has
+        completed (and any error from it has been raised) before the success
+        message is printed.
+        """
+        params = {
+            "milvus_id": chunk.milvus_id,
+            "doc_id": doc.doc_id,
+            "dataset_id": doc.dataset_id,
+            "chunk_id": chunk.chunk_id,
+            "topic": chunk.topic,
+            "domain": chunk.domain,
+            "text_type": chunk.text_type,
+            "task_type": chunk.task_type,
+            "entities": relations.entities,
+            "concepts": relations.concepts,
+            "keywords": relations.keywords,
+            "domain_name": relations.domain,
+            "topic_name": relations.topic,
+        }
+        async with self.neo4j.session() as session:
+            result = await session.run(QUERY, **params)
+            await result.consume()
+        print(f"✅ {doc.doc_id} - {chunk.chunk_id} 已写入 Neo4j (async)")

+ 2 - 1
requirements.txt

@@ -20,4 +20,5 @@ quart-cors==0.8.0
 tiktoken==0.11.0
 uvloop==0.21.0
 elasticsearch==8.17.2
-scikit-learn==1.7.2
+scikit-learn==1.7.2
+neo4j==5.28.2

+ 14 - 1
routes/buleprint.py

@@ -9,7 +9,7 @@ from quart_cors import cors
 
 from applications.api import get_basic_embedding
 from applications.api import get_img_embedding
-from applications.async_task import AutoRechunkTask
+from applications.async_task import AutoRechunkTask, BuildGraph
 from applications.async_task import ChunkEmbeddingTask, DeleteTask
 from applications.config import (
     DEFAULT_MODEL,
@@ -538,3 +538,16 @@ async def auto_rechunk():
     auto_rechunk_task = AutoRechunkTask(mysql_client=resource.mysql_client)
     process_cnt = await auto_rechunk_task.deal()
     return jsonify({"status_code": 200, "detail": "success", "cnt": process_cnt})
+
+
+# NOTE(review): the handler is still called ``delete_task`` (apparently copied
+# from another route). The endpoint is named explicitly so it cannot collide
+# with another view of that name on the blueprint - confirm and rename the
+# function if nothing else refers to it.
+@server_bp.route("/build_graph", methods=["POST"], endpoint="build_graph")
+async def delete_task():
+    """Build the Neo4j graph for the document named by ``doc_id`` in the JSON body.
+
+    Responds with ``status_code`` 500 when ``doc_id`` is missing or empty and
+    with 200 once ``BuildGraph.deal`` has finished.
+    """
+    # get_json() gives None for a missing/non-JSON body; a JSON array or scalar
+    # has no ``get`` either, so both are treated as an empty object.
+    body = await request.get_json()
+    if not isinstance(body, dict):
+        body = {}
+    doc_id: str = body.get("doc_id")
+    if not doc_id:
+        return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
+
+    resource = get_resource_manager()
+    build_graph_task = BuildGraph(neo4j=resource.graph_client, mysql_client=resource.mysql_client)
+    await build_graph_task.deal(doc_id)
+    return jsonify({"status_code": 200, "detail": "success", "data": {}})