
Merge branch 'feature/luojunhui/2025-09-24-add-graph' of Server/rag_server into master

luojunhui, 2 weeks ago
parent commit 6e89cce338

+ 2 - 1
applications/async_task/__init__.py

@@ -1,6 +1,7 @@
 from .chunk_task import ChunkEmbeddingTask
 from .delete_task import DeleteTask
 from .auto_rechunk_task import AutoRechunkTask
+from .build_graph import BuildGraph
 
 
-__all__ = ["ChunkEmbeddingTask", "DeleteTask", "AutoRechunkTask"]
+__all__ = ["ChunkEmbeddingTask", "DeleteTask", "AutoRechunkTask", "BuildGraph"]

+ 77 - 0
applications/async_task/build_graph.py

@@ -0,0 +1,77 @@
+"""
+Build a knowledge graph in Neo4j from a document's chunks.
+"""
+
+from dataclasses import fields
+
+from applications.utils.neo4j import AsyncNeo4jRepository
+from applications.utils.neo4j import Document, GraphChunk, ChunkRelations
+from applications.utils.mysql import ContentChunks
+from applications.utils.async_utils import run_tasks_with_asyncio_task_group
+
+
+class BuildGraph(AsyncNeo4jRepository):
+    INIT_STATUS = 0
+    PROCESSING_STATUS = 1
+    FINISHED_STATUS = 2
+    FAILED_STATUS = 3
+
+    def __init__(self, neo4j, es_client, mysql_client):
+        super().__init__(neo4j)
+        self.es_client = es_client
+        self.chunk_manager = ContentChunks(mysql_client)
+
+    @staticmethod
+    def from_dict(cls, data: dict):
+        field_names = {f.name for f in fields(cls)}
+        return cls(**{k: v for k, v in data.items() if k in field_names})
+
+    async def add_single_chunk(self, param):
+        """async process single chunk"""
+        chunk_id = param["chunk_id"]
+        doc_id = param["doc_id"]
+        acquire_lock = await self.chunk_manager.update_graph_status(
+            doc_id, chunk_id, self.INIT_STATUS, self.PROCESSING_STATUS
+        )
+        if not acquire_lock:
+            print(f"skip chunk {chunk_id}: graph-build lock not acquired (already claimed)")
+            return
+
+        try:
+            doc: Document = self.from_dict(Document, param)
+            graph_chunk: GraphChunk = self.from_dict(GraphChunk, param)
+            relations: ChunkRelations = self.from_dict(ChunkRelations, param)
+
+            await self.add_document_with_chunk(doc, graph_chunk, relations)
+            await self.chunk_manager.update_graph_status(
+                doc_id, chunk_id, self.PROCESSING_STATUS, self.FINISHED_STATUS
+            )
+        except Exception as e:
+            print(f"failed to build graph for chunk {chunk_id}: {e}")
+            await self.chunk_manager.update_graph_status(
+                doc_id, chunk_id, self.PROCESSING_STATUS, self.FAILED_STATUS
+            )
+
+    async def get_chunk_list_from_es(self, doc_id):
+        """async get chunk list"""
+        query = {
+            "query": {"bool": {"must": [{"term": {"doc_id": doc_id}}]}},
+            "_source": True,
+        }
+        try:
+            resp = await self.es_client.async_search(query=query)
+            return [hit["_source"] for hit in resp["hits"]["hits"]]
+        except Exception as e:
+            print(f"search failed: {e}")
+            return []
+
+    async def deal(self, doc_id):
+        """async process single chunk"""
+        chunk_list = await self.get_chunk_list_from_es(doc_id)
+        await run_tasks_with_asyncio_task_group(
+            task_list=chunk_list,
+            handler=self.add_single_chunk,
+            description="build graph",
+            unit="chunk",
+            max_concurrency=10,
+        )
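
For orientation, a minimal sketch of driving the new task directly. The get_resource_manager import path is an assumption (it appears to live in applications/resource/resource_manager.py, changed below); the /build_graph route added in routes/buleprint.py does the same wiring.

    import asyncio

    from applications.async_task import BuildGraph
    from applications.resource.resource_manager import get_resource_manager  # assumed path

    async def main(doc_id: str) -> None:
        resource = get_resource_manager()
        task = BuildGraph(
            neo4j=resource.graph_client,         # neo4j AsyncDriver
            es_client=resource.es_client,        # AsyncElasticSearchClient
            mysql_client=resource.mysql_client,  # DatabaseManager
        )
        await task.deal(doc_id)  # fans chunks out with max_concurrency=10

    asyncio.run(main("example-doc-id"))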

+ 21 - 6
applications/async_task/delete_task.py

@@ -6,6 +6,9 @@ from applications.utils.milvus import async_delete_chunk
 
 
 class DeleteTask:
+    USEFUL_STATUS = 1
+    DELETE_STATUS = 0
+
     def __init__(self, resource):
         self.mysql_client = resource.mysql_client
         self.milvus_client = resource.milvus_client
@@ -65,7 +68,9 @@ class DeleteTask:
         chunk_id = params["chunk_id"]
         try:
             self.chunk_manager = ContentChunks(self.mysql_client)
-            await self.chunk_manager.update_doc_chunk_status(doc_id, chunk_id, 1, 0)
+            await self.chunk_manager.update_doc_chunk_status(
+                doc_id, chunk_id, self.USEFUL_STATUS, self.DELETE_STATUS
+            )
             await self.delete_by_filters({"doc_id": doc_id, "chunk_id": chunk_id})
             return {"doc_id": doc_id, "chunk_id": chunk_id, "status": "success"}
         except Exception as e:
@@ -77,8 +82,12 @@ class DeleteTask:
         try:
             self.chunk_manager = ContentChunks(self.mysql_client)
             self.content_manager = Contents(self.mysql_client)
-            await self.chunk_manager.update_doc_status(doc_id, 1, 0)
-            await self.content_manager.update_doc_status(doc_id, 1, 0)
+            await self.chunk_manager.update_doc_status(
+                doc_id, self.USEFUL_STATUS, self.DELETE_STATUS
+            )
+            await self.content_manager.update_doc_status(
+                doc_id, self.USEFUL_STATUS, self.DELETE_STATUS
+            )
 
             await self.delete_by_filters({"doc_id": doc_id})
             return {"doc_id": doc_id, "status": "success"}
@@ -92,9 +101,15 @@ class DeleteTask:
             self.chunk_manager = ContentChunks(self.mysql_client)
             self.content_manager = Contents(self.mysql_client)
             self.dataset_manager = Dataset(self.mysql_client)
-            await self.chunk_manager.update_dataset_status(dataset_id, 1, 0)
-            await self.content_manager.update_dataset_status(dataset_id, 1, 0)
-            await self.dataset_manager.update_dataset_status(dataset_id, 1, 0)
+            await self.chunk_manager.update_dataset_status(
+                dataset_id, self.USEFUL_STATUS, self.DELETE_STATUS
+            )
+            await self.content_manager.update_dataset_status(
+                dataset_id, self.USEFUL_STATUS, self.DELETE_STATUS
+            )
+            await self.dataset_manager.update_dataset_status(
+                dataset_id, self.USEFUL_STATUS, self.DELETE_STATUS
+            )
 
             await self.delete_by_filters({"dataset_id": dataset_id})
             return {"dataset_id": dataset_id, "status": "success"}

+ 2 - 0
applications/config/__init__.py

@@ -9,6 +9,7 @@ from .base_chunk import Chunk, ChunkerConfig
 from .elastic_search_config import ELASTIC_SEARCH_INDEX, ES_HOSTS, ES_PASSWORD
 from .milvus_config import MILVUS_CONFIG, BASE_MILVUS_SEARCH_PARAMS
 from .mysql_config import RAG_MYSQL_CONFIG
+from .neo4j_config import NEO4j_CONFIG
 from .weight_config import WEIGHT_MAP
 
 
@@ -28,4 +29,5 @@ __all__ = [
     "ES_PASSWORD",
     "ELASTIC_SEARCH_INDEX",
     "BASE_MILVUS_SEARCH_PARAMS",
+    "NEO4j_CONFIG",
 ]

+ 5 - 0
applications/config/neo4j_config.py

@@ -0,0 +1,5 @@
+NEO4j_CONFIG = {
+    "url": "bolt://192.168.100.31:7687",
+    "user": "neo4j",
+    "password": "ljh000118",
+}

+ 11 - 0
applications/resource/resource_manager.py

@@ -1,5 +1,7 @@
 from pymilvus import connections, CollectionSchema, Collection
+from neo4j import AsyncGraphDatabase, AsyncDriver
 
+from applications.config import NEO4j_CONFIG
 from applications.utils.mysql import DatabaseManager
 from applications.utils.milvus.field import fields
 from applications.utils.elastic_search import AsyncElasticSearchClient
@@ -15,6 +17,7 @@ class ResourceManager:
         self.es_client: AsyncElasticSearchClient | None = None
         self.milvus_client: Collection | None = None
         self.mysql_client: DatabaseManager | None = None
+        self.graph_client: AsyncDriver | None = None
 
     async def load_milvus(self):
         connections.connect("default", **self.milvus_config)
@@ -54,6 +57,11 @@ class ResourceManager:
         await self.load_milvus()
         print("✅ Milvus loaded")
 
+        uri: str = NEO4j_CONFIG["url"]
+        auth: tuple = (NEO4j_CONFIG["user"], NEO4j_CONFIG["password"])
+        self.graph_client = AsyncGraphDatabase.driver(uri=uri, auth=auth)
+        print("✅ Neo4j loaded")
+
     async def shutdown(self):
         # Close Elasticsearch
         if self.es_client:
@@ -69,6 +77,9 @@ class ResourceManager:
             await self.mysql_client.close_pools()
             print("Mysql closed")
 
+        if self.graph_client:
+            await self.graph_client.close()
+            print("Graph closed")
+
 
 _resource_manager: ResourceManager | None = None
 

+ 3 - 1
applications/utils/mysql/__init__.py

@@ -1,5 +1,7 @@
 from .pool import DatabaseManager
-from .mapper import Contents, ContentChunks, Dataset, ChatResult
+from .mapper import Dataset, ChatResult
+from .content_chunks import ContentChunks
+from .contents import Contents
 
 
 __all__ = ["Contents", "ContentChunks", "DatabaseManager", "Dataset", "ChatResult"]

+ 12 - 0
applications/utils/mysql/base.py

@@ -0,0 +1,12 @@
+class TaskConst:
+    INIT_STATUS = 0
+    PROCESSING_STATUS = 1
+    FINISHED_STATUS = 2
+    FAILED_STATUS = 3
+
+    CHUNK_USEFUL_STATUS = 1
+
+
+class BaseMySQLClient(TaskConst):
+    def __init__(self, pool):
+        self.pool = pool
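
These constants back a compare-and-swap pattern: each status UPDATE matches only rows still in the expected source state, so its affected-row count doubles as a lock-acquisition result. A sketch of the intended flow, using the ContentChunks mapper defined below:

    # Optimistic-lock flow built on the TaskConst statuses.
    async def process_chunk_with_lock(chunk_manager, doc_id, chunk_id):
        # 1 affected row: we won the INIT -> PROCESSING transition;
        # 0 affected rows: another worker already claimed this chunk.
        locked = await chunk_manager.update_graph_status(
            doc_id, chunk_id, chunk_manager.INIT_STATUS, chunk_manager.PROCESSING_STATUS
        )
        if not locked:
            return
        try:
            ...  # do the real work, e.g. write the chunk into Neo4j
            await chunk_manager.update_graph_status(
                doc_id, chunk_id, chunk_manager.PROCESSING_STATUS, chunk_manager.FINISHED_STATUS
            )
        except Exception:
            await chunk_manager.update_graph_status(
                doc_id, chunk_id, chunk_manager.PROCESSING_STATUS, chunk_manager.FAILED_STATUS
            )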

+ 190 - 0
applications/utils/mysql/content_chunks.py

@@ -0,0 +1,190 @@
+import json
+
+from applications.config import Chunk
+from .base import BaseMySQLClient
+
+
+class ContentChunks(BaseMySQLClient):
+    async def insert_chunk(self, chunk: Chunk) -> int:
+        query = """
+            INSERT IGNORE INTO content_chunks
+                (chunk_id, doc_id, text, tokens, topic_purity, text_type, dataset_id, status) 
+                VALUES (%s, %s, %s, %s, %s, %s, %s, %s);
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(
+                chunk.chunk_id,
+                chunk.doc_id,
+                chunk.text,
+                chunk.tokens,
+                chunk.topic_purity,
+                chunk.text_type,
+                chunk.dataset_id,
+                chunk.status,
+            ),
+        )
+
+    # Update the chunking status of a single chunk
+    async def update_chunk_status(self, doc_id, chunk_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks
+            SET chunk_status = %s 
+            WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s and status = %s;
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(new_status, doc_id, chunk_id, ori_status, self.CHUNK_USEFUL_STATUS),
+        )
+
+    # Update the embedding status of a single chunk
+    async def update_embedding_status(self, doc_id, chunk_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks
+            SET embedding_status = %s 
+            WHERE doc_id = %s AND chunk_id = %s AND embedding_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, chunk_id, ori_status)
+        )
+
+    # Store the chunking result and mark the chunk status as finished
+    async def set_chunk_result(self, chunk: Chunk, ori_status, new_status):
+        query = """
+            UPDATE content_chunks
+            SET summary = %s, topic = %s, domain = %s, task_type = %s, concepts = %s, 
+                keywords = %s, questions = %s, chunk_status = %s, entities = %s
+            WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(
+                chunk.summary,
+                chunk.topic,
+                chunk.domain,
+                chunk.task_type,
+                json.dumps(chunk.concepts),
+                json.dumps(chunk.keywords),
+                json.dumps(chunk.questions),
+                new_status,
+                json.dumps(chunk.entities),
+                chunk.doc_id,
+                chunk.chunk_id,
+                ori_status,
+            ),
+        )
+
+    # Update the ES indexing status
+    async def update_es_status(self, doc_id, chunk_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks SET es_status = %s
+            WHERE doc_id = %s AND chunk_id = %s AND es_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, chunk_id, ori_status)
+        )
+
+    # Update the availability status of a single chunk
+    async def update_doc_chunk_status(self, doc_id, chunk_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks set status = %s 
+            WHERE doc_id = %s AND chunk_id = %s AND status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, chunk_id, ori_status)
+        )
+
+    # Update the availability status of a single doc
+    async def update_doc_status(self, doc_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks set status = %s 
+            WHERE doc_id = %s AND status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, ori_status)
+        )
+
+    # Update the availability status of a dataset
+    async def update_dataset_status(self, dataset_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks set status = %s 
+            WHERE dataset_id = %s AND status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, dataset_id, ori_status)
+        )
+
+    # Update the graph-build status
+    async def update_graph_status(self, doc_id, chunk_id, ori_status, new_status):
+        query = """
+            UPDATE content_chunks SET graph_status = %s
+            WHERE doc_id = %s AND chunk_id = %s AND graph_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, chunk_id, ori_status)
+        )
+
+    async def select_chunk_content(self, doc_id, chunk_id):
+        query = """
+            SELECT * FROM content_chunks WHERE doc_id = %s AND chunk_id = %s;
+        """
+        return await self.pool.async_fetch(query=query, params=(doc_id, chunk_id))
+
+    async def select_chunk_contents(
+        self,
+        page_num: int,
+        page_size: int,
+        order_by=None,
+        doc_id: str = None,
+        doc_status: int = None,
+    ):
+        if order_by is None:
+            order_by = {"chunk_id": "asc"}
+        offset = (page_num - 1) * page_size
+
+        # Dynamically assemble the WHERE clauses
+        where_clauses = []
+        params = []
+
+        if doc_id:
+            where_clauses.append("doc_id = %s")
+            params.append(doc_id)
+
+        if doc_status is not None:
+            where_clauses.append("doc_status = %s")
+            params.append(doc_status)
+
+        # Fall back to a tautology so the WHERE clause stays valid with no filters
+        where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
+
+        # Dynamically assemble the ORDER BY clause
+        order_field, order_direction = list(order_by.items())[0]
+        order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
+
+        # Query the total row count
+        count_query = (
+            f"SELECT COUNT(*) as total_count FROM content_chunks WHERE {where_sql};"
+        )
+        count_result = await self.pool.async_fetch(
+            query=count_query, params=tuple(params)
+        )
+        total_count = count_result[0]["total_count"] if count_result else 0
+
+        # Query the paginated rows
+        query = f"""
+            SELECT * FROM content_chunks
+            WHERE {where_sql}
+            {order_sql}
+            LIMIT %s OFFSET %s;
+        """
+        params.extend([page_size, offset])
+        entities = await self.pool.async_fetch(query=query, params=tuple(params))
+
+        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
+        return {
+            "entities": entities,
+            "total_count": total_count,
+            "page": page_num,
+            "page_size": page_size,
+            "total_pages": total_pages,
+        }
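
A short usage sketch for the paginated query; pool is assumed to be a connected DatabaseManager, and the values are illustrative:

    # List the first page of a document's chunks.
    async def list_chunks(pool, doc_id: str):
        chunks = ContentChunks(pool)
        page = await chunks.select_chunk_contents(
            page_num=1,
            page_size=20,
            order_by={"chunk_id": "asc"},
            doc_id=doc_id,
        )
        # The returned dict mirrors the keys assembled above.
        print(page["total_count"], page["total_pages"])
        return page["entities"]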

+ 131 - 0
applications/utils/mysql/contents.py

@@ -0,0 +1,131 @@
+from .base import BaseMySQLClient
+
+class Contents(BaseMySQLClient):
+    async def insert_content(self, doc_id, text, text_type, title, dataset_id):
+        query = """
+            INSERT IGNORE INTO contents
+                (doc_id, text, text_type, title, dataset_id)
+            VALUES (%s, %s, %s, %s, %s);
+        """
+        return await self.pool.async_save(
+            query=query, params=(doc_id, text, text_type, title, dataset_id)
+        )
+
+    async def update_content_info(self, doc_id, text, text_type, title, dataset_id):
+        query = """
+            UPDATE contents 
+            SET text = %s, text_type = %s, title = %s, dataset_id = %s, status = %s
+            WHERE doc_id = %s;
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(text, text_type, title, dataset_id, self.INIT_STATUS, doc_id),
+        )
+
+    async def update_content_status(self, doc_id, ori_status, new_status):
+        query = """
+            UPDATE contents
+            SET status = %s
+            WHERE doc_id = %s AND status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, ori_status)
+        )
+
+    async def update_dataset_status(self, dataset_id, ori_status, new_status):
+        query = """
+            UPDATE contents
+            SET status = %s
+            WHERE dataset_id = %s AND status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, dataset_id, ori_status)
+        )
+
+    async def update_doc_status(self, doc_id, ori_status, new_status):
+        """
+        this function is to change the using status of each document
+        :param doc_id:
+        :param ori_status:
+        :param new_status:
+        :return:
+        """
+        query = """
+            UPDATE contents SET doc_status = %s WHERE doc_id = %s AND doc_status = %s;
+        """
+        return await self.pool.async_save(
+            query=query, params=(new_status, doc_id, ori_status)
+        )
+
+    async def select_count(self, dataset_id, doc_status=1):
+        query = """
+            SELECT count(*) AS count FROM contents WHERE dataset_id = %s AND doc_status = %s;
+        """
+        rows = await self.pool.async_fetch(query=query, params=(dataset_id, doc_status))
+        return rows[0]["count"] if rows else 0
+
+    async def select_content_by_doc_id(self, doc_id):
+        query = """SELECT * FROM contents WHERE doc_id = %s;"""
+        return await self.pool.async_fetch(query=query, params=(doc_id,))
+
+    async def select_contents(
+        self,
+        page_num: int,
+        page_size: int,
+        order_by=None,
+        dataset_id: int = None,
+        doc_status: int = 1,
+    ):
+        """
+        分页查询 contents 表,并返回分页信息
+        :param page_num: 页码,从 1 开始
+        :param page_size: 每页数量
+        :param order_by: 排序条件,例如 {"id": "desc"} 或 {"created_at": "asc"}
+        :param dataset_id: 数据集 ID
+        :param doc_status: 文档状态(默认 1)
+        :return: dict,包含 entities、total_count、page、page_size、total_pages
+        """
+        if order_by is None:
+            order_by = {"id": "desc"}
+        offset = (page_num - 1) * page_size
+
+        # Dynamically assemble the WHERE clauses
+        where_clauses = ["doc_status = %s"]
+        params = [doc_status]
+
+        if dataset_id:
+            where_clauses.append("dataset_id = %s")
+            params.append(dataset_id)
+
+        where_sql = " AND ".join(where_clauses)
+
+        # Dynamically assemble the ORDER BY clause
+        order_field, order_direction = list(order_by.items())[0]
+        order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
+
+        # Query the total row count
+        count_query = f"SELECT COUNT(*) as total_count FROM contents WHERE {where_sql};"
+        count_result = await self.pool.async_fetch(
+            query=count_query, params=tuple(params)
+        )
+        total_count = count_result[0]["total_count"] if count_result else 0
+
+        # Query the paginated rows
+        query = f"""
+            SELECT * FROM contents
+            WHERE {where_sql}
+            {order_sql}
+            LIMIT %s OFFSET %s;
+        """
+        params.extend([page_size, offset])
+        entities = await self.pool.async_fetch(query=query, params=tuple(params))
+
+        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
+
+        return {
+            "entities": entities,
+            "total_count": total_count,
+            "page": page_num,
+            "page_size": page_size,
+            "total_pages": total_pages,
+        }

+ 1 - 315
applications/utils/mysql/mapper.py

@@ -1,20 +1,4 @@
-import json
-
-from applications.config import Chunk
-
-
-class TaskConst:
-    INIT_STATUS = 0
-    PROCESSING_STATUS = 1
-    FINISHED_STATUS = 2
-    FAILED_STATUS = 3
-
-    CHUNK_USEFUL_STATUS = 1
-
-
-class BaseMySQLClient(TaskConst):
-    def __init__(self, pool):
-        self.pool = pool
+from .base import BaseMySQLClient
 
 
 class Dataset(BaseMySQLClient):
@@ -51,304 +35,6 @@ class Dataset(BaseMySQLClient):
         return await self.pool.async_fetch(query=query, params=(name, status))
 
 
-class Contents(BaseMySQLClient):
-    async def insert_content(self, doc_id, text, text_type, title, dataset_id):
-        query = """
-            INSERT IGNORE INTO contents
-                (doc_id, text, text_type, title, dataset_id)
-            VALUES (%s, %s, %s, %s, %s);
-        """
-        return await self.pool.async_save(
-            query=query, params=(doc_id, text, text_type, title, dataset_id)
-        )
-
-    async def update_content_info(self, doc_id, text, text_type, title, dataset_id):
-        query = """
-            UPDATE contents 
-            SET text = %s, text_type = %s, title = %s, dataset_id = %s, status = %s
-            WHERE doc_id = %s;
-        """
-        return await self.pool.async_save(
-            query=query,
-            params=(text, text_type, title, dataset_id, self.INIT_STATUS, doc_id),
-        )
-
-    async def update_content_status(self, doc_id, ori_status, new_status):
-        query = """
-            UPDATE contents
-            SET status = %s
-            WHERE doc_id = %s AND status = %s;
-        """
-        return await self.pool.async_save(
-            query=query, params=(new_status, doc_id, ori_status)
-        )
-
-    async def update_dataset_status(self, dataset_id, ori_status, new_status):
-        query = """
-            UPDATE contents
-            SET status = %s
-            WHERE dataset_id = %s AND status = %s;
-        """
-        return await self.pool.async_save(
-            query=query, params=(new_status, dataset_id, ori_status)
-        )
-
-    async def update_doc_status(self, doc_id, ori_status, new_status):
-        """
-        this function is to change the using status of each document
-        :param doc_id:
-        :param ori_status:
-        :param new_status:
-        :return:
-        """
-        query = """
-            UPDATE contents SET doc_status = %s WHERE doc_id = %s AND doc_status = %s;
-        """
-        return await self.pool.async_save(
-            query=query, params=(new_status, doc_id, ori_status)
-        )
-
-    async def select_count(self, dataset_id, doc_status=1):
-        query = """
-            SELECT count(*) AS count FROM contents WHERE dataset_id = %s AND doc_status = %s;
-        """
-        rows = await self.pool.async_fetch(query=query, params=(dataset_id, doc_status))
-        return rows[0]["count"] if rows else 0
-
-    async def select_content_by_doc_id(self, doc_id):
-        query = """SELECT * FROM contents WHERE doc_id = %s;"""
-        return await self.pool.async_fetch(query=query, params=(doc_id,))
-
-    async def select_contents(
-        self,
-        page_num: int,
-        page_size: int,
-        order_by=None,
-        dataset_id: int = None,
-        doc_status: int = 1,
-    ):
-        """
-        Paginated query over the contents table, returning pagination info.
-        :param page_num: page number, starting from 1
-        :param page_size: rows per page
-        :param order_by: sort condition, e.g. {"id": "desc"} or {"created_at": "asc"}
-        :param dataset_id: dataset ID
-        :param doc_status: document status (defaults to 1)
-        :return: dict containing entities, total_count, page, page_size, total_pages
-        """
-        if order_by is None:
-            order_by = {"id": "desc"}
-        offset = (page_num - 1) * page_size
-
-        # Dynamically assemble the WHERE clauses
-        where_clauses = ["doc_status = %s"]
-        params = [doc_status]
-
-        if dataset_id:
-            where_clauses.append("dataset_id = %s")
-            params.append(dataset_id)
-
-        where_sql = " AND ".join(where_clauses)
-
-        # Dynamically assemble the ORDER BY clause
-        order_field, order_direction = list(order_by.items())[0]
-        order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
-
-        # Query the total row count
-        count_query = f"SELECT COUNT(*) as total_count FROM contents WHERE {where_sql};"
-        count_result = await self.pool.async_fetch(
-            query=count_query, params=tuple(params)
-        )
-        total_count = count_result[0]["total_count"] if count_result else 0
-
-        # Query the paginated rows
-        query = f"""
-            SELECT * FROM contents
-            WHERE {where_sql}
-            {order_sql}
-            LIMIT %s OFFSET %s;
-        """
-        params.extend([page_size, offset])
-        entities = await self.pool.async_fetch(query=query, params=tuple(params))
-
-        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
-
-        return {
-            "entities": entities,
-            "total_count": total_count,
-            "page": page_num,
-            "page_size": page_size,
-            "total_pages": total_pages,
-        }
-
-
-class ContentChunks(BaseMySQLClient):
-    async def insert_chunk(self, chunk: Chunk) -> int:
-        query = """
-            INSERT IGNORE INTO content_chunks
-                (chunk_id, doc_id, text, tokens, topic_purity, text_type, dataset_id, status) 
-                VALUES (%s, %s, %s, %s, %s, %s, %s, %s);
-        """
-        return await self.pool.async_save(
-            query=query,
-            params=(
-                chunk.chunk_id,
-                chunk.doc_id,
-                chunk.text,
-                chunk.tokens,
-                chunk.topic_purity,
-                chunk.text_type,
-                chunk.dataset_id,
-                chunk.status,
-            ),
-        )
-
-    async def update_chunk_status(self, doc_id, chunk_id, ori_status, new_status):
-        query = """
-            UPDATE content_chunks
-            SET chunk_status = %s 
-            WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s and status = %s;
-        """
-        return await self.pool.async_save(
-            query=query,
-            params=(new_status, doc_id, chunk_id, ori_status, self.CHUNK_USEFUL_STATUS),
-        )
-
-    async def update_embedding_status(self, doc_id, chunk_id, ori_status, new_status):
-        query = """
-            UPDATE content_chunks
-            SET embedding_status = %s 
-            WHERE doc_id = %s AND chunk_id = %s AND embedding_status = %s;
-        """
-        return await self.pool.async_save(
-            query=query, params=(new_status, doc_id, chunk_id, ori_status)
-        )
-
-    async def set_chunk_result(self, chunk: Chunk, ori_status, new_status):
-        query = """
-            UPDATE content_chunks
-            SET summary = %s, topic = %s, domain = %s, task_type = %s, concepts = %s, 
-                keywords = %s, questions = %s, chunk_status = %s, entities = %s
-            WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s;
-        """
-        return await self.pool.async_save(
-            query=query,
-            params=(
-                chunk.summary,
-                chunk.topic,
-                chunk.domain,
-                chunk.task_type,
-                json.dumps(chunk.concepts),
-                json.dumps(chunk.keywords),
-                json.dumps(chunk.questions),
-                new_status,
-                json.dumps(chunk.entities),
-                chunk.doc_id,
-                chunk.chunk_id,
-                ori_status,
-            ),
-        )
-
-    async def update_es_status(self, doc_id, chunk_id, ori_status, new_status):
-        query = """
-            UPDATE content_chunks SET es_status = %s
-            WHERE doc_id = %s AND chunk_id = %s AND es_status = %s;
-        """
-        return await self.pool.async_save(
-            query=query, params=(new_status, doc_id, chunk_id, ori_status)
-        )
-
-    async def update_doc_chunk_status(self, doc_id, chunk_id, ori_status, new_status):
-        query = """
-            UPDATE content_chunks set status = %s 
-            WHERE doc_id = %s AND chunk_id = %s AND status = %s;
-        """
-        return await self.pool.async_save(
-            query=query, params=(new_status, doc_id, chunk_id, ori_status)
-        )
-
-    async def update_doc_status(self, doc_id, ori_status, new_status):
-        query = """
-            UPDATE content_chunks set status = %s 
-            WHERE doc_id = %s AND status = %s;
-        """
-        return await self.pool.async_save(
-            query=query, params=(new_status, doc_id, ori_status)
-        )
-
-    async def update_dataset_status(self, dataset_id, ori_status, new_status):
-        query = """
-            UPDATE content_chunks set status = %s 
-            WHERE dataset_id = %s AND status = %s;
-        """
-        return await self.pool.async_save(
-            query=query, params=(new_status, dataset_id, ori_status)
-        )
-
-    async def select_chunk_content(self, doc_id, chunk_id):
-        query = """
-            SELECT * FROM content_chunks WHERE doc_id = %s AND chunk_id = %s;
-        """
-        return await self.pool.async_fetch(query=query, params=(doc_id, chunk_id))
-
-    async def select_chunk_contents(
-        self,
-        page_num: int,
-        page_size: int,
-        order_by: dict = {"chunk_id": "asc"},
-        doc_id: str = None,
-        doc_status: int = None,
-    ):
-        offset = (page_num - 1) * page_size
-
-        # Dynamically assemble the WHERE clauses
-        where_clauses = []
-        params = []
-
-        if doc_id:
-            where_clauses.append("doc_id = %s")
-            params.append(doc_id)
-
-        if doc_status:
-            where_clauses.append("doc_status = %s")
-            params.append(doc_status)
-
-        where_sql = " AND ".join(where_clauses)
-
-        # Dynamically assemble the ORDER BY clause
-        order_field, order_direction = list(order_by.items())[0]
-        order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
-
-        # Query the total row count
-        count_query = (
-            f"SELECT COUNT(*) as total_count FROM content_chunks WHERE {where_sql};"
-        )
-        count_result = await self.pool.async_fetch(
-            query=count_query, params=tuple(params)
-        )
-        total_count = count_result[0]["total_count"] if count_result else 0
-
-        # Query the paginated rows
-        query = f"""
-            SELECT * FROM content_chunks
-            WHERE {where_sql}
-            {order_sql}
-            LIMIT %s OFFSET %s;
-        """
-        params.extend([page_size, offset])
-        entities = await self.pool.async_fetch(query=query, params=tuple(params))
-
-        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
-        print(total_pages)
-        return {
-            "entities": entities,
-            "total_count": total_count,
-            "page": page_num,
-            "page_size": page_size,
-            "total_pages": total_pages,
-        }
-
-
 class ChatResult(BaseMySQLClient):
     async def insert_chat_result(
         self, query_text, dataset_ids, search_res, chat_res, score, has_answer

+ 9 - 0
applications/utils/neo4j/__init__.py

@@ -0,0 +1,9 @@
+from .repository import AsyncNeo4jRepository
+from .models import Document, ChunkRelations, GraphChunk
+
+__all__ = [
+    "AsyncNeo4jRepository",
+    "Document",
+    "ChunkRelations",
+    "GraphChunk",
+]

+ 95 - 0
applications/utils/neo4j/models.py

@@ -0,0 +1,95 @@
+from dataclasses import dataclass
+from typing import List
+
+
+@dataclass
+class Document:
+    doc_id: str
+    dataset_id: int
+
+
+@dataclass
+class GraphChunk:
+    milvus_id: int
+    chunk_id: int
+    doc_id: str
+    topic: str
+    domain: str
+    text_type: int
+    task_type: str
+
+
+@dataclass
+class ChunkRelations:
+    entities: List[str]
+    concepts: List[str]
+    keywords: List[str]
+    domain: str
+    topic: str
+
+
+QUERY = """
+// 1) Document & GraphChunk
+MERGE (d:Document {doc_id: $doc_id})
+  ON CREATE SET d.dataset_id = $dataset_id
+  SET d.dataset_id = $dataset_id
+
+MERGE (gc:GraphChunk {milvus_id: $milvus_id})
+  ON CREATE SET gc.chunk_id = $chunk_id, gc.doc_id = $doc_id
+  SET gc.topic     = $topic,
+      gc.domain    = $domain,
+      gc.text_type = $text_type,
+      gc.task_type = $task_type,
+      gc.doc_id    = $doc_id
+
+MERGE (gc)-[:BELONGS_TO]->(d)
+MERGE (d)-[:HAS_CHUNK]->(gc)
+
+// 2) Parameter preparation
+WITH gc,
+     COALESCE($entities,  []) AS entities,
+     COALESCE($concepts,  []) AS concepts,
+     COALESCE($keywords,  []) AS keywords,
+     $domain_name AS domain_name,
+     $topic_name  AS topic_name
+
+// 3) Entities
+UNWIND entities AS e_name
+  WITH gc, e_name, concepts, keywords, domain_name, topic_name
+  WITH gc, TRIM(e_name) AS e_name, concepts, keywords, domain_name, topic_name
+  WHERE e_name <> ""
+  MERGE (e:Entity {name: e_name})
+  MERGE (gc)-[:HAS_ENTITY]->(e)
+
+// 4) Concepts
+WITH gc, concepts, keywords, domain_name, topic_name
+UNWIND concepts AS c_name
+  WITH gc, c_name, keywords, domain_name, topic_name
+  WITH gc, TRIM(c_name) AS c_name, keywords, domain_name, topic_name
+  WHERE c_name <> ""
+  MERGE (co:Concept {name: c_name})
+  MERGE (gc)-[:HAS_CONCEPT]->(co)
+
+// 5) Keywords
+WITH gc, keywords, domain_name, topic_name
+UNWIND keywords AS k_name
+  WITH gc, k_name, domain_name, topic_name
+  WITH gc, TRIM(k_name) AS k_name, domain_name, topic_name
+  WHERE k_name <> ""
+  MERGE (k:Keyword {name: k_name})
+  MERGE (gc)-[:HAS_KEYWORD]->(k)
+
+// 6) Domain (conditional; FOREACH used instead of CALL)
+WITH gc, domain_name, topic_name
+FOREACH (_ IN CASE WHEN domain_name IS NOT NULL AND TRIM(domain_name) <> "" THEN [1] ELSE [] END |
+  MERGE (d_node:Domain {name: TRIM(domain_name)})
+  MERGE (gc)-[:HAS_DOMAIN]->(d_node)
+)
+
+// 7) Topic (conditional; FOREACH used instead of CALL)
+WITH gc, topic_name
+FOREACH (_ IN CASE WHEN topic_name IS NOT NULL AND TRIM(topic_name) <> "" THEN [1] ELSE [] END |
+  MERGE (t:Topic {name: TRIM(topic_name)})
+  MERGE (gc)-[:HAS_TOPIC]->(t)
+)
+"""

+ 28 - 0
applications/utils/neo4j/query.py

@@ -0,0 +1,28 @@
+class AsyncNeo4jQuery:
+    def __init__(self, neo4j):
+        self.neo4j = neo4j
+
+    async def close(self):
+        await self.neo4j.close()
+
+    async def get_document_by_id(self, doc_id: str):
+        query = """
+        MATCH (d:Document {doc_id: $doc_id})
+        OPTIONAL MATCH (d)-[:HAS_CHUNK]->(c:GraphChunk)
+        RETURN d, collect(c) as chunks
+        """
+        async with self.neo4j.session() as session:
+            result = await session.run(query, doc_id=doc_id)
+            # AsyncResult must be iterated asynchronously; consume() only returns a summary
+            return [record.data() async for record in result]
+
+    async def search_chunks_by_topic(self, topic: str):
+        query = """
+        MATCH (c:GraphChunk {topic: $topic})
+        OPTIONAL MATCH (c)-[:HAS_ENTITY]->(e:Entity)
+        RETURN c, collect(e.name) as entities
+        """
+        async with self.neo4j.session() as session:
+            result = await session.run(query, topic=topic)
+            return [record.data() async for record in result]
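
A hedged usage sketch; the URI and credentials are placeholders (in the service they come from NEO4j_CONFIG via the resource manager):

    import asyncio

    from neo4j import AsyncGraphDatabase

    async def main():
        driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
        q = AsyncNeo4jQuery(driver)
        rows = await q.search_chunks_by_topic("retrieval")
        print(rows)  # e.g. [{"c": {...}, "entities": ["Neo4j"]}, ...]
        await q.close()

    asyncio.run(main())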

+ 28 - 0
applications/utils/neo4j/repository.py

@@ -0,0 +1,28 @@
+from .models import Document, GraphChunk, ChunkRelations, QUERY
+
+
+class AsyncNeo4jRepository:
+    def __init__(self, neo4j):
+        self.neo4j = neo4j
+
+    async def add_document_with_chunk(
+        self, doc: Document, chunk: GraphChunk, relations: ChunkRelations
+    ):
+        async with self.neo4j.session() as session:
+            await session.run(
+                QUERY,
+                milvus_id=chunk.milvus_id,
+                doc_id=doc.doc_id,
+                dataset_id=doc.dataset_id,
+                chunk_id=chunk.chunk_id,
+                topic=chunk.topic,
+                domain=chunk.domain,
+                text_type=chunk.text_type,
+                task_type=chunk.task_type,
+                entities=relations.entities,
+                concepts=relations.concepts,
+                keywords=relations.keywords,
+                domain_name=relations.domain,
+                topic_name=relations.topic,
+            )
+        print(f"✅ {doc.doc_id} - {chunk.chunk_id} 已写入 Neo4j (async)")

+ 2 - 1
requirements.txt

@@ -20,4 +20,5 @@ quart-cors==0.8.0
 tiktoken==0.11.0
 uvloop==0.21.0
 elasticsearch==8.17.2
-scikit-learn==1.7.2
+scikit-learn==1.7.2
+neo4j==5.28.2

+ 14 - 1
routes/buleprint.py

@@ -9,7 +9,7 @@ from quart_cors import cors
 
 from applications.api import get_basic_embedding
 from applications.api import get_img_embedding
-from applications.async_task import AutoRechunkTask
+from applications.async_task import AutoRechunkTask, BuildGraph
 from applications.async_task import ChunkEmbeddingTask, DeleteTask
 from applications.config import (
     DEFAULT_MODEL,
@@ -550,3 +550,16 @@ async def auto_rechunk():
     auto_rechunk_task = AutoRechunkTask(mysql_client=resource.mysql_client)
     process_cnt = await auto_rechunk_task.deal()
     return jsonify({"status_code": 200, "detail": "success", "cnt": process_cnt})
+
+
+@server_bp.route("/build_graph", methods=["POST"])
+async def build_graph():
+    body = await request.get_json()
+    doc_id: str = body.get("doc_id")
+    if not doc_id:
+        return jsonify({"status_code": 500, "detail": "doc_id is required", "data": {}})
+
+    resource = get_resource_manager()
+    build_graph_task = BuildGraph(
+        neo4j=resource.graph_client,
+        es_client=resource.es_client,
+        mysql_client=resource.mysql_client,
+    )
+    await build_graph_task.deal(doc_id)
+    return jsonify({"status_code": 200, "detail": "success", "data": {}})
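
A hedged request example for the new endpoint; the host/port are placeholders, any blueprint URL prefix is omitted, and aiohttp is just one possible client:

    import asyncio

    import aiohttp

    async def trigger_build(doc_id: str):
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "http://localhost:8001/build_graph", json={"doc_id": doc_id}
            ) as resp:
                print(await resp.json())  # {"status_code": 200, "detail": "success", "data": {}}

    asyncio.run(trigger_build("example-doc-id"))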