12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- import asyncio
- import uuid
- from typing import List
- from applications.utils.mysql import ContentChunks, Contents
- from applications.utils.chunks import TopicAwareChunker, LLMClassifier
- from applications.config import DEFAULT_MODEL, Chunk, ChunkerConfig
class ChunkTask(TopicAwareChunker):
    """Split a document into chunks, persist them and classify each chunk.

    The status constants used below (``INIT_STATUS``, ``PROCESSING_STATUS``,
    ``FINISHED_STATUS``, ``FAILED_STATUS``) and ``self.chunk`` are not defined
    in this class; presumably they are inherited from ``TopicAwareChunker`` --
    TODO confirm.
    """

    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig):
        """Store the connection pools and create the LLM classifier.

        Args:
            mysql_pool: Pool handed to the ``Contents`` / ``ContentChunks``
                processors in :meth:`init_processer`.
            vector_pool: Stored on the instance; not used by this class.
            cfg: Chunker configuration forwarded to the base class.
        """
        super().__init__(cfg)
        # Both processors are created later by init_processer().
        self.content_chunk_processor = None
        self.contents_processor = None
        self.mysql_pool = mysql_pool
        self.vector_pool = vector_pool
        self.classifier = LLMClassifier()

    def init_processer(self):
        """Create the content and content-chunk processors from the MySQL pool."""
        self.contents_processor = Contents(self.mysql_pool)
        self.content_chunk_processor = ContentChunks(self.mysql_pool)

    async def process_content(self, doc_id, text) -> List[Chunk]:
        """Insert the document, chunk its text and update the document status.

        Args:
            doc_id: Identifier under which the document is stored.
            text: Raw document text to be chunked.

        Returns:
            The chunks produced by ``self.chunk(text)``, or an empty list when
            the insert reported failure or chunking produced nothing.
        """
        inserted = await self.contents_processor.insert_content(doc_id, text)
        if not inserted:
            return []

        raw_chunks = await self.chunk(text)
        if not raw_chunks:
            # Nothing to process: move the document from INIT to FAILED.
            await self.contents_processor.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []

        affected_rows = await self.contents_processor.update_content_status(
            doc_id=doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        print(affected_rows)
        return raw_chunks

    async def process_each_chunk(self, chunk: Chunk):
        """Insert one chunk, classify it and store the classification result.

        The chunk moves INIT -> PROCESSING -> FINISHED, or to FAILED when
        classification yields nothing or the result cannot be stored.

        Args:
            chunk: The chunk to persist and classify.
        """
        inserted = await self.content_chunk_processor.insert_chunk(chunk)
        if not inserted:
            return

        # The INIT -> PROCESSING transition doubles as a lock: a falsy result
        # means the chunk was not claimed here, so stop.
        acquire_lock = await self.content_chunk_processor.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            return

        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self.content_chunk_processor.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            # Stop here: without a classification result there is nothing to
            # store, and set_chunk_result must not be called with an empty value.
            return

        update_flag = await self.content_chunk_processor.set_chunk_result(
            chunk=completion,
            new_status=self.FINISHED_STATUS,
            ori_status=self.PROCESSING_STATUS,
        )
        if not update_flag:
            await self.content_chunk_processor.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )

    async def deal(self, data):
        """Run the whole pipeline for one request.

        Args:
            data: Mapping that is read with ``data.get("text")``.

        Returns:
            The generated ``doc-<uuid4>`` identifier, or ``None`` when ``data``
            carries no text.
        """
        text = data.get("text")
        if not text:
            return None

        self.init_processer()
        doc_id = f"doc-{uuid.uuid4()}"

        async def _process():
            chunks = await self.process_content(doc_id, text)
            if not chunks:
                return

            # Process all chunks concurrently, one task per chunk; the
            # TaskGroup block exits only after every task has completed.
            async with asyncio.TaskGroup() as tg:
                for chunk in chunks:
                    tg.create_task(self.process_each_chunk(chunk))

            await self.contents_processor.update_content_status(
                doc_id=doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )

        # NOTE: processing is awaited inline rather than scheduled as a
        # background task, so doc_id is returned only after _process() ends.
        await _process()
        return doc_id
|