Prechádzať zdrojové kódy

新增 rechunk 功能

luojunhui 2 týždňov pred
rodič
commit
bcf076ac4d

+ 22 - 5
applications/async_task/chunk_task.py

@@ -29,11 +29,26 @@ class ChunkEmbeddingTask(TopicAwarePackerV2):
         self.chunk_manager = ContentChunks(self.mysql_client)
 
     async def _chunk_each_content(
-        self, doc_id: str, text: str, text_type: int, title: str, dataset_id: int
+        self,
+        doc_id: str,
+        text: str,
+        text_type: int,
+        title: str,
+        dataset_id: int,
+        re_chunk: bool,
     ) -> List[Chunk]:
-        flag = await self.content_manager.insert_content(
-            doc_id, text, text_type, title, dataset_id
-        )
+        if re_chunk:
+            flag = await self.content_manager.update_content_info(
+                doc_id=doc_id,
+                text=text,
+                text_type=text_type,
+                title=title,
+                dataset_id=dataset_id,
+            )
+        else:
+            flag = await self.content_manager.insert_content(
+                doc_id, text, text_type, title, dataset_id
+            )
         if not flag:
             return []
         else:
@@ -203,6 +218,8 @@ class ChunkEmbeddingTask(TopicAwarePackerV2):
         text, title = text.strip(), title.strip()
         text_type = data.get("text_type", 1)
         dataset_id = data.get("dataset_id", 0)  # 默认知识库 id 为 0
+        re_chunk = data.get("re_chunk", False)
+
         if not text:
             return None
 
@@ -210,7 +227,7 @@ class ChunkEmbeddingTask(TopicAwarePackerV2):
 
         async def _process():
             chunks = await self._chunk_each_content(
-                self.doc_id, text, text_type, title, dataset_id
+                self.doc_id, text, text_type, title, dataset_id, re_chunk
             )
             if not chunks:
                 return

+ 11 - 0
applications/utils/mysql/mapper.py

@@ -56,6 +56,17 @@ class Contents(BaseMySQLClient):
             query=query, params=(doc_id, text, text_type, title, dataset_id)
         )
 
+    async def update_content_info(self, doc_id, text, text_type, title, dataset_id):
+        query = """
+            UPDATE contents 
+            SET text = %s, text_type = %s, title = %s, dataset_id = %s, status = %s
+            WHERE doc_id = %s;
+        """
+        return await self.pool.async_save(
+            query=query,
+            params=(text, text_type, title, dataset_id, self.INIT_STATUS, doc_id),
+        )
+
     async def update_content_status(self, doc_id, ori_status, new_status):
         query = """
             UPDATE contents

+ 9 - 1
routes/buleprint.py

@@ -63,11 +63,19 @@ async def delete():
 async def chunk():
     body = await request.get_json()
     text = body.get("text", "")
+    ori_doc_id = body.get("doc_id")
     text = text.strip()
     if not text:
         return jsonify({"error": "error  text"})
     resource = get_resource_manager()
-    doc_id = f"doc-{uuid.uuid4()}"
+
+    # generate doc id
+    if ori_doc_id:
+        body["re_chunk"] = True
+        doc_id = ori_doc_id
+    else:
+        doc_id = f"doc-{uuid.uuid4()}"
+
     chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
     doc_id = await chunk_task.deal(body)
     return jsonify({"doc_id": doc_id})