# chunk_task.py
  1. import asyncio
  2. import uuid
  3. from typing import List
  4. from applications.utils.mysql import ContentChunks, Contents
  5. from applications.utils.chunks import TopicAwareChunker, LLMClassifier
  6. from applications.config import DEFAULT_MODEL, Chunk, ChunkerConfig
  7. class ChunkTask(TopicAwareChunker):
  8. def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig):
  9. super().__init__(cfg)
  10. self.content_chunk_processor = None
  11. self.contents_processor = None
  12. self.mysql_pool = mysql_pool
  13. self.vector_pool = vector_pool
  14. self.classifier = LLMClassifier()
  15. def init_processer(self):
  16. self.contents_processor = Contents(self.mysql_pool)
  17. self.content_chunk_processor = ContentChunks(self.mysql_pool)
  18. async def process_content(self, doc_id, text) -> List[Chunk]:
  19. flag = await self.contents_processor.insert_content(doc_id, text)
  20. if not flag:
  21. return []
  22. else:
  23. raw_chunks = await self.chunk(text)
  24. if not raw_chunks:
  25. await self.contents_processor.update_content_status(
  26. doc_id=doc_id, ori_status=self.INIT_STATUS, new_status=self.FAILED_STATUS
  27. )
  28. return []
  29. affected_rows = await self.contents_processor.update_content_status(
  30. doc_id=doc_id, ori_status=self.INIT_STATUS, new_status=self.PROCESSING_STATUS
  31. )
  32. print(affected_rows)
  33. return raw_chunks
  34. async def process_each_chunk(self, chunk: Chunk):
  35. # insert
  36. flag = await self.content_chunk_processor.insert_chunk(chunk)
  37. if not flag:
  38. return
  39. acquire_lock = await self.content_chunk_processor.update_chunk_status(
  40. doc_id=chunk.doc_id, chunk_id=chunk.chunk_id, ori_status=self.INIT_STATUS, new_status=self.PROCESSING_STATUS
  41. )
  42. if not acquire_lock:
  43. return
  44. completion = await self.classifier.classify_chunk(chunk)
  45. if not completion:
  46. await self.content_chunk_processor.update_chunk_status(
  47. doc_id=chunk.doc_id, chunk_id=chunk.chunk_id, ori_status=self.PROCESSING_STATUS, new_status=self.FAILED_STATUS
  48. )
  49. update_flag = await self.content_chunk_processor.set_chunk_result(
  50. chunk=completion, new_status=self.FINISHED_STATUS, ori_status=self.PROCESSING_STATUS
  51. )
  52. if not update_flag:
  53. await self.content_chunk_processor.update_chunk_status(
  54. doc_id=chunk.doc_id, chunk_id=chunk.chunk_id, ori_status=self.PROCESSING_STATUS, new_status=self.FAILED_STATUS
  55. )
  56. async def deal(self, data):
  57. text = data.get("text")
  58. if not text:
  59. return None
  60. self.init_processer()
  61. doc_id = f"doc-{uuid.uuid4()}"
  62. async def _process():
  63. chunks = await self.process_content(doc_id, text)
  64. if not chunks:
  65. return
  66. # 开始分batch
  67. async with asyncio.TaskGroup() as tg:
  68. for chunk in chunks:
  69. tg.create_task(self.process_each_chunk(chunk))
  70. await self.contents_processor.update_content_status(
  71. doc_id=doc_id, ori_status=self.PROCESSING_STATUS, new_status=self.FINISHED_STATUS
  72. )
  73. await _process()
  74. # asyncio.create_task(_process())
  75. return doc_id