
chunk-init

luojunhui 3 weeks ago
parent commit fae07d6e60

+ 6 - 5
applications/async_task/chunk_task.py

@@ -27,12 +27,12 @@ class ChunkEmbeddingTask(TopicAwareChunker):
         self.contents_processor = Contents(self.mysql_pool)
         self.content_chunk_processor = ContentChunks(self.mysql_pool)
 
-    async def process_content(self, doc_id, text) -> List[Chunk]:
-        flag = await self.contents_processor.insert_content(doc_id, text)
+    async def process_content(self, doc_id: str, text: str, text_type: int) -> List[Chunk]:
+        flag = await self.contents_processor.insert_content(doc_id, text, text_type)
         if not flag:
             return []
         else:
-            raw_chunks = await self.chunk(text)
+            raw_chunks = await self.chunk(text, text_type)
             if not raw_chunks:
                 await self.contents_processor.update_content_status(
                     doc_id=doc_id,
@@ -41,7 +41,7 @@ class ChunkEmbeddingTask(TopicAwareChunker):
                 )
                 return []
 
-            affected_rows = await self.contents_processor.update_content_status(
+            await self.contents_processor.update_content_status(
                 doc_id=doc_id,
                 ori_status=self.INIT_STATUS,
                 new_status=self.PROCESSING_STATUS,
@@ -146,13 +146,14 @@ class ChunkEmbeddingTask(TopicAwareChunker):
     async def deal(self, data):
         text = data.get("text", "")
         text = text.strip()
+        text_type = data.get("text_type", 1)
         if not text:
             return None
 
         self.init_processer()
 
         async def _process():
-            chunks = await self.process_content(self.doc_id, text)
+            chunks = await self.process_content(self.doc_id, text, text_type)
             if not chunks:
                 return
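
A minimal, self-contained sketch of the default handling the updated deal() now performs on its payload (the helper name extract_text_and_type is hypothetical; only the .get() keys and the fallback value of 1 come from this diff):

    def extract_text_and_type(data: dict) -> tuple:
        # Mirrors deal(): strip the text and fall back to text_type=1 when the
        # caller does not send one.
        text = data.get("text", "").strip()
        text_type = data.get("text_type", 1)
        return text, text_type

    assert extract_text_and_type({"text": " hi "}) == ("hi", 1)
    assert extract_text_and_type({"text": "hi", "text_type": 2}) == ("hi", 2)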
 

+ 1 - 0
applications/config/base_chunk.py

@@ -11,6 +11,7 @@ class Chunk:
     domain: str = ""
     task_type: str = ""
     topic_purity: float = 1.0
+    text_type: int = 1
     summary: str = ""
     keywords: List[str] = field(default_factory=list)
     concepts: List[str] = field(default_factory=list)
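
For context, a trimmed reconstruction of the Chunk dataclass after this change; the fields above domain and their defaults are assumptions inferred from the constructor call in topic_aware_chunking.py, not part of this diff. Existing call sites that do not pass text_type keep working because of the default of 1:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Chunk:
        doc_id: str = ""          # assumed defaults; only the fields from
        chunk_id: int = 0         # domain downward are visible in this diff
        text: str = ""
        tokens: int = 0
        domain: str = ""
        task_type: str = ""
        topic_purity: float = 1.0
        text_type: int = 1        # new field added by this commit
        summary: str = ""
        keywords: List[str] = field(default_factory=list)
        concepts: List[str] = field(default_factory=list)

    legacy = Chunk(doc_id="d1", chunk_id=1, text="hello", tokens=2)
    assert legacy.text_type == 1  # pre-existing callers are unaffected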

+ 4 - 4
applications/utils/chunks/topic_aware_chunking.py

@@ -108,7 +108,7 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
         return np.stack(embs)
 
     def _pack_by_boundaries(
-        self, sentence_list: List[str], boundaries: List[int]
+        self, sentence_list: List[str], boundaries: List[int], text_type: int
     ) -> List[Chunk]:
         boundary_set = set(boundaries)
         chunks: List[Chunk] = []
@@ -136,7 +136,7 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
             tokens = num_tokens(text)
             chunk_id += 1
             chunk = Chunk(
-                doc_id=self.doc_id, chunk_id=chunk_id, text=text, tokens=tokens
+                doc_id=self.doc_id, chunk_id=chunk_id, text=text, tokens=tokens, text_type=text_type
             )
             chunks.append(chunk)
             start = end + 1
@@ -163,14 +163,14 @@ class TopicAwareChunker(BoundaryDetector, SplitTextIntoSentences):
         finally:
             self.cfg.boundary_threshold = orig
 
-    async def chunk(self, text: str) -> List[Chunk]:
+    async def chunk(self, text: str, text_type: int) -> List[Chunk]:
         sentence_list = self.jieba_sent_tokenize(text)
         if not sentence_list:
             return []
 
         sentences_embeddings = await self._encode_batch(sentence_list)
         boundaries = self.detect_boundaries(sentence_list, sentences_embeddings)
-        raw_chunks = self._pack_by_boundaries(sentence_list, boundaries)
+        raw_chunks = self._pack_by_boundaries(sentence_list, boundaries, text_type)
         return raw_chunks
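
A simplified, runnable stand-in for the threading that _pack_by_boundaries now does (the real method also tracks tokens, topic purity, etc.; MiniChunk and the boundary handling here are illustrative assumptions):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class MiniChunk:  # stand-in for the real Chunk dataclass
        chunk_id: int
        text: str
        text_type: int

    def pack_by_boundaries(sentences: List[str], boundaries: List[int],
                           text_type: int) -> List[MiniChunk]:
        # Split the sentence list at each boundary index and stamp every chunk
        # with the caller's text_type, as the updated method does.
        chunks, start, chunk_id = [], 0, 0
        for end in sorted(set(boundaries)) + [len(sentences) - 1]:
            if end < start:
                continue
            chunk_id += 1
            chunks.append(MiniChunk(chunk_id, " ".join(sentences[start:end + 1]), text_type))
            start = end + 1
        return chunks

    print(pack_by_boundaries(["a.", "b.", "c."], [0], text_type=2))
    # -> [MiniChunk(chunk_id=1, text='a.', text_type=2),
    #     MiniChunk(chunk_id=2, text='b. c.', text_type=2)]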
 
 

+ 7 - 6
applications/utils/mysql/mapper.py

@@ -17,13 +17,13 @@ class BaseMySQLClient:
 
 class Contents(BaseMySQLClient):
 
-    async def insert_content(self, doc_id, text):
+    async def insert_content(self, doc_id, text, text_type):
         query = """
             INSERT IGNORE INTO contents
-                (doc_id, text)
-            VALUES (%s, %s);
+                (doc_id, text, text_type)
+            VALUES (%s, %s, %s);
         """
-        return await self.pool.async_save(query=query, params=(doc_id, text))
+        return await self.pool.async_save(query=query, params=(doc_id, text, text_type))
 
     async def update_content_status(self, doc_id, ori_status, new_status):
         query = """
@@ -41,8 +41,8 @@ class ContentChunks(BaseMySQLClient):
     async def insert_chunk(self, chunk: Chunk) -> int:
         query = """
             INSERT IGNORE INTO content_chunks
-                (chunk_id, doc_id, text, tokens, topic_purity) 
-                VALUES (%s, %s, %s, %s, %s);
+                (chunk_id, doc_id, text, tokens, topic_purity, text_type) 
+                VALUES (%s, %s, %s, %s, %s, %s);
         """
         return await self.pool.async_save(
             query=query,
@@ -52,6 +52,7 @@ class ContentChunks(BaseMySQLClient):
                 chunk.text,
                 chunk.tokens,
                 chunk.topic_purity,
+                chunk.text_type
             ),
         )
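
These INSERT statements assume contents and content_chunks already carry a text_type column. If they do not, a one-off migration roughly like the following would be needed; the INT type and DEFAULT 1 are assumptions mirroring the dataclass default, not part of this commit:

    # Hypothetical migration; run once against the same MySQL database.
    ADD_TEXT_TYPE_COLUMNS = [
        "ALTER TABLE contents ADD COLUMN text_type INT NOT NULL DEFAULT 1;",
        "ALTER TABLE content_chunks ADD COLUMN text_type INT NOT NULL DEFAULT 1;",
    ]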