import asyncio
import uuid
from typing import List

from applications.utils.mysql import ContentChunks, Contents
from applications.utils.chunks import TopicAwareChunker, LLMClassifier
from applications.config import DEFAULT_MODEL, Chunk, ChunkerConfig


class ChunkTask(TopicAwareChunker):
    """Chunk a document, persist the chunks, and classify each one with an LLM.

    The status constants used throughout (``INIT_STATUS``,
    ``PROCESSING_STATUS``, ``FINISHED_STATUS``, ``FAILED_STATUS``) and the
    ``chunk`` coroutine are read from ``self``; they are not defined in this
    class, so presumably they come from ``TopicAwareChunker`` -- confirm there.
    """

    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig):
        """Store the connection pools and create the classifier.

        The two MySQL-backed processors are left as ``None`` here and are
        created by :meth:`init_processer`.
        """
        super().__init__(cfg)
        self.content_chunk_processor = None
        self.contents_processor = None
        self.mysql_pool = mysql_pool
        self.vector_pool = vector_pool
        self.classifier = LLMClassifier()

    def init_processer(self):
        """Create the ``Contents`` and ``ContentChunks`` helpers on the MySQL pool."""
        self.contents_processor = Contents(self.mysql_pool)
        self.content_chunk_processor = ContentChunks(self.mysql_pool)

    async def process_content(self, doc_id, text) -> List[Chunk]:
        """Insert the document row, chunk the text, and move the row to PROCESSING.

        Returns:
            The chunks produced by ``self.chunk(text)``, or an empty list when
            the insert reports failure or chunking yields nothing (in the
            latter case the row is moved from INIT to FAILED first).
        """
        inserted = await self.contents_processor.insert_content(doc_id, text)
        if not inserted:
            return []

        raw_chunks = await self.chunk(text)
        if not raw_chunks:
            await self.contents_processor.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []

        affected_rows = await self.contents_processor.update_content_status(
            doc_id=doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        # NOTE(review): debug output kept as-is; consider routing this through
        # the logging module instead of stdout.
        print(affected_rows)
        return raw_chunks

    async def _mark_chunk_failed(self, chunk: Chunk):
        """Move ``chunk`` from PROCESSING to FAILED."""
        await self.content_chunk_processor.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.PROCESSING_STATUS,
            new_status=self.FAILED_STATUS,
        )

    async def process_each_chunk(self, chunk: Chunk) -> None:
        """Persist one chunk, classify it, and store the classification result.

        Steps: insert the chunk row, move it INIT -> PROCESSING (a falsy
        result from that update is treated as "do not proceed"), classify it,
        then write the result with status FINISHED. A falsy classification or
        a falsy result from the final write moves the chunk to FAILED.
        """
        inserted = await self.content_chunk_processor.insert_chunk(chunk)
        if not inserted:
            return

        acquire_lock = await self.content_chunk_processor.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            return

        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self._mark_chunk_failed(chunk)
            # Stop here: without this return the falsy completion would be
            # passed on to set_chunk_result below even though the chunk has
            # just been marked FAILED.
            return

        update_flag = await self.content_chunk_processor.set_chunk_result(
            chunk=completion,
            new_status=self.FINISHED_STATUS,
            ori_status=self.PROCESSING_STATUS,
        )
        if not update_flag:
            await self._mark_chunk_failed(chunk)

    async def deal(self, data) -> "str | None":
        """Run the full pipeline for one request payload.

        Args:
            data: Mapping-like object; only ``data.get("text")`` is read.

        Returns:
            The generated ``doc-<uuid4>`` identifier, or ``None`` when the
            payload carries no text.
        """
        text = data.get("text")
        if not text:
            return None

        self.init_processer()
        doc_id = f"doc-{uuid.uuid4()}"

        async def _process():
            chunks = await self.process_content(doc_id, text)
            if not chunks:
                return

            # Start batch processing: one task per chunk, all run concurrently
            # and are awaited when the TaskGroup block exits.
            async with asyncio.TaskGroup() as tg:
                for chunk in chunks:
                    tg.create_task(self.process_each_chunk(chunk))

            await self.contents_processor.update_content_status(
                doc_id=doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )

        # Awaited inline, so doc_id is returned only after processing ends.
        # A fire-and-forget asyncio.create_task(_process()) variant was left
        # disabled in the original.
        await _process()
        return doc_id