Переглянути джерело

Merge branch 'master' into dev-xym-relocation

# Conflicts:
#	applications/utils/chat/chat_classifier.py
#	applications/utils/mysql/mapper.py
#	routes/buleprint.py
xueyiming 2 тижнів тому
батько
коміт
2ab7ecc2ec

+ 23 - 5
applications/async_task/chunk_task.py

@@ -29,11 +29,27 @@ class ChunkEmbeddingTask(TopicAwarePackerV2):
         self.chunk_manager = ContentChunks(self.mysql_client)
 
     async def _chunk_each_content(
-        self, doc_id: str, text: str, text_type: int, title: str, dataset_id: int
+        self,
+        doc_id: str,
+        text: str,
+        text_type: int,
+        title: str,
+        dataset_id: int,
+        re_chunk: bool,
     ) -> List[Chunk]:
-        flag = await self.content_manager.insert_content(
-            doc_id, text, text_type, title, dataset_id
-        )
+        if re_chunk:
+            await self.content_manager.update_content_info(
+                doc_id=doc_id,
+                text=text,
+                text_type=text_type,
+                title=title,
+                dataset_id=dataset_id,
+            )
+            flag = True
+        else:
+            flag = await self.content_manager.insert_content(
+                doc_id, text, text_type, title, dataset_id
+            )
         if not flag:
             return []
         else:
@@ -203,6 +219,8 @@ class ChunkEmbeddingTask(TopicAwarePackerV2):
         text, title = text.strip(), title.strip()
         text_type = data.get("text_type", 1)
         dataset_id = data.get("dataset_id", 0)  # 默认知识库 id 为 0
+        re_chunk = data.get("re_chunk", False)
+
         if not text:
             return None
 
@@ -210,7 +228,7 @@ class ChunkEmbeddingTask(TopicAwarePackerV2):
 
         async def _process():
             chunks = await self._chunk_each_content(
-                self.doc_id, text, text_type, title, dataset_id
+                self.doc_id, text, text_type, title, dataset_id, re_chunk
             )
             if not chunks:
                 return

+ 28 - 24
applications/utils/mysql/mapper.py

@@ -1,5 +1,4 @@
 import json
-from datetime import datetime
 
 from applications.config import Chunk
 
@@ -21,32 +20,29 @@ class BaseMySQLClient(TaskConst):
 class Dataset(BaseMySQLClient):
     async def update_dataset_status(self, dataset_id, ori_status, new_status):
         query = """
-            UPDATE dataset set status = %s where id = %s and status = %s;
+            UPDATE dataset SET status = %s WHERE id = %s AND status = %s;
         """
         return await self.pool.async_save(
-            query=query,
-            params=(new_status, dataset_id, ori_status),
+            query=query, params=(new_status, dataset_id, ori_status)
         )
 
     async def select_dataset(self, status=1):
         query = """
-            select * from dataset where status = %s;
+            SELECT * FROM dataset WHERE status = %s;
         """
         return await self.pool.async_fetch(query=query, params=(status,))
 
     async def add_dataset(self, name):
         query = """
-            insert into dataset (name, created_at, updated_at, status) values (%s, %s, %s, %s);
+            INSERT INTO dataset (name) VALUES (%s);
         """
-        return await self.pool.async_save(
-            query=query, params=(name, datetime.now(), datetime.now(), 1)
-        )
+        return await self.pool.async_save(query=query, params=(name,))
 
-    async def select_dataset_by_id(self, id, status=1):
+    async def select_dataset_by_id(self, id_, status: int = 1):
         query = """
-            select * from dataset where id = %s and status = %s;
+            SELECT * FROM dataset WHERE id = %s AND status = %s;
         """
-        return await self.pool.async_fetch(query=query, params=(id, status))
+        return await self.pool.async_fetch(query=query, params=(id_, status))
 
 
 class Contents(BaseMySQLClient):
@@ -60,6 +56,17 @@ class Contents(BaseMySQLClient):
             query=query, params=(doc_id, text, text_type, title, dataset_id)
         )
 
+    async def update_content_info(self, doc_id, text, text_type, title, dataset_id):
+        query = """
+            UPDATE contents 
+            SET text = %s, text_type = %s, title = %s, dataset_id = %s, status = %s
+            WHERE doc_id = %s;
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(text, text_type, title, dataset_id, self.INIT_STATUS, doc_id),
+        )
+
     async def update_content_status(self, doc_id, ori_status, new_status):
         query = """
             UPDATE contents
@@ -89,8 +96,7 @@ class Contents(BaseMySQLClient):
         :return:
         """
         query = """
-            UPDATE contents
-            SET doc_status = %s WHERE doc_id = %s and doc_status = %s;
+            UPDATE contents SET doc_status = %s WHERE doc_id = %s AND doc_status = %s;
         """
         return await self.pool.async_save(
             query=query, params=(new_status, doc_id, ori_status)
@@ -98,22 +104,20 @@ class Contents(BaseMySQLClient):
 
     async def select_count(self, dataset_id, doc_status=1):
         query = """
-            select count(*) as count from contents where dataset_id = %s and doc_status = %s;
+            SELECT count(*) AS count FROM contents WHERE dataset_id = %s AND doc_status = %s;
         """
         rows = await self.pool.async_fetch(query=query, params=(dataset_id, doc_status))
         return rows[0]["count"] if rows else 0
 
     async def select_content_by_doc_id(self, doc_id):
-        query = """
-            select * from contents where doc_id = %s;
-        """
+        query = """SELECT * FROM contents WHERE doc_id = %s;"""
         return await self.pool.async_fetch(query=query, params=(doc_id,))
 
     async def select_contents(
         self,
         page_num: int,
         page_size: int,
-        order_by: dict = {"id": "desc"},
+        order_by=None,
         dataset_id: int = None,
         doc_status: int = 1,
     ):
@@ -126,6 +130,8 @@ class Contents(BaseMySQLClient):
         :param doc_status: 文档状态(默认 1)
         :return: dict,包含 entities、total_count、page、page_size、total_pages
         """
+        if order_by is None:
+            order_by = {"id": "desc"}
         offset = (page_num - 1) * page_size
 
         # 动态拼接 where 条件
@@ -273,13 +279,11 @@ class ContentChunks(BaseMySQLClient):
             query=query, params=(new_status, dataset_id, ori_status)
         )
 
-    async def select_chunk_content(self, doc_id, chunk_id, status=1):
+    async def select_chunk_content(self, doc_id, chunk_id):
         query = """
-            select * from content_chunks where doc_id = %s and chunk_id = %s and status = %s;
+            SELECT * FROM content_chunks WHERE doc_id = %s AND chunk_id = %s;
         """
-        return await self.pool.async_fetch(
-            query=query, params=(doc_id, chunk_id, status)
-        )
+        return await self.pool.async_fetch(query=query, params=(doc_id, chunk_id))
 
     async def select_chunk_contents(
         self,

+ 17 - 12
routes/buleprint.py

@@ -10,7 +10,6 @@ from quart_cors import cors
 from applications.config import (
     DEFAULT_MODEL,
     LOCAL_MODEL_CONFIG,
-    ChunkerConfig,
     BASE_MILVUS_SEARCH_PARAMS,
 )
 from applications.resource import get_resource_manager
@@ -66,11 +65,19 @@ async def delete():
 async def chunk():
     body = await request.get_json()
     text = body.get("text", "")
+    ori_doc_id = body.get("doc_id")
     text = text.strip()
     if not text:
         return jsonify({"error": "error  text"})
     resource = get_resource_manager()
-    doc_id = f"doc-{uuid.uuid4()}"
+
+    # generate doc id
+    if ori_doc_id:
+        body["re_chunk"] = True
+        doc_id = ori_doc_id
+    else:
+        doc_id = f"doc-{uuid.uuid4()}"
+
     chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
     doc_id = await chunk_task.deal(body)
     return jsonify({"doc_id": doc_id})
@@ -198,17 +205,15 @@ async def get_content():
 
     row = rows[0]
 
-    return jsonify(
-        {
-            "status_code": 200,
-            "detail": "success",
-            "data": {
-                "title": row.get("title", ""),
-                "text": row.get("text", ""),
-                "doc_id": row.get("doc_id", ""),
-            },
+    return jsonify({
+        "status_code": 200,
+        "detail": "success",
+        "data": {
+            "title": row.get("title", ""),
+            "text": row.get("text", ""),
+            "doc_id": row.get("doc_id", "")
         }
-    )
+    })
 
 
 @server_bp.route("/content/list", methods=["GET"])