chunk_task.py
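
"""
Async task that chunks a document, classifies each chunk with an LLM,
embeds the results, and writes them to Milvus.

Pipeline: insert content -> topic-aware chunking -> per-chunk LLM
classification (MySQL) -> embedding -> Milvus insert. Each step advances a
status column via a compare-and-set update, so concurrent workers cannot
pick up the same row twice.
"""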

import asyncio
import uuid
from typing import List

from applications.api import get_basic_embedding
from applications.utils.async_utils import run_tasks_with_asyncio_task_group
from applications.utils.mysql import ContentChunks, Contents
from applications.utils.chunks import TopicAwareChunker, LLMClassifier
from applications.utils.milvus import async_insert_chunk
from applications.config import Chunk, ChunkerConfig, DEFAULT_MODEL


class ChunkEmbeddingTask(TopicAwareChunker):
    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig):
        super().__init__(cfg)
        self.content_chunk_processor = None
        self.contents_processor = None
        self.mysql_pool = mysql_pool
        self.vector_pool = vector_pool
        self.classifier = LLMClassifier()
        # hold strong references to fire-and-forget tasks so they are not
        # garbage-collected before they finish (see `deal`)
        self._background_tasks = set()

    @staticmethod
    async def get_embedding_list(text: str) -> List:
        return await get_basic_embedding(text=text, model=DEFAULT_MODEL, dev=True)

    def init_processor(self):
        self.contents_processor = Contents(self.mysql_pool)
        self.content_chunk_processor = ContentChunks(self.mysql_pool)

    async def process_content(self, doc_id, text) -> List[Chunk]:
        # bail out if the document could not be inserted (e.g. it already exists)
        flag = await self.contents_processor.insert_content(doc_id, text)
        if not flag:
            return []
        raw_chunks = await self.chunk(text)
        if not raw_chunks:
            await self.contents_processor.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []
        await self.contents_processor.update_content_status(
            doc_id=doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        return raw_chunks

    async def process_each_chunk(self, chunk: Chunk):
        # insert the chunk row; skip the chunk if the insert fails
        flag = await self.content_chunk_processor.insert_chunk(chunk)
        if not flag:
            return

        # grab the processing lock via a compare-and-set status update
        acquire_lock = await self.content_chunk_processor.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            return

        # classify the chunk with the LLM; mark it failed on an empty result
        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self.content_chunk_processor.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return

        # persist the classification result and advance the status
        update_flag = await self.content_chunk_processor.set_chunk_result(
            chunk=completion,
            ori_status=self.PROCESSING_STATUS,
            new_status=self.FINISHED_STATUS,
        )
        if not update_flag:
            await self.content_chunk_processor.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return

        await self.save_to_milvus(completion)

    async def save_to_milvus(self, chunk: Chunk):
        """
        Embed a single chunk and write it to Milvus.

        :param chunk: each single chunk
        """
        # grab the embedding lock via a compare-and-set status update
        acquire_lock = await self.content_chunk_processor.update_embedding_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print(
                f"Failed to acquire the embedding lock for chunk {chunk.doc_id}-{chunk.chunk_id}"
            )
            return
        try:
            data = {
                "doc_id": chunk.doc_id,
                "chunk_id": chunk.chunk_id,
                "vector_text": await self.get_embedding_list(chunk.text),
                "vector_summary": await self.get_embedding_list(chunk.summary),
                "vector_questions": await self.get_embedding_list(
                    ",".join(chunk.questions)
                ),
                "topic": chunk.topic,
                "domain": chunk.domain,
                "task_type": chunk.task_type,
                "summary": chunk.summary,
                "keywords": chunk.keywords,
                "concepts": chunk.concepts,
                "questions": chunk.questions,
                "topic_purity": chunk.topic_purity,
                "tokens": chunk.tokens,
            }
            await async_insert_chunk(self.vector_pool, data)
            await self.content_chunk_processor.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
        except Exception as e:
            await self.content_chunk_processor.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print(f"Failed to insert chunk into the vector store: {e}")

    async def deal(self, data):
        text = data.get("text")
        if not text:
            return None

        self.init_processor()
        doc_id = f"doc-{uuid.uuid4()}"

        async def _process():
            chunks = await self.process_content(doc_id, text)
            if not chunks:
                return
            await run_tasks_with_asyncio_task_group(
                task_list=chunks,
                handler=self.process_each_chunk,
                description="process the chunks of a single document",
                unit="chunk",
                max_concurrency=10,
            )
            await self.contents_processor.update_content_status(
                doc_id=doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )

        # run the pipeline in the background and return the doc_id immediately;
        # keep a strong reference so the task is not garbage-collected mid-flight
        task = asyncio.create_task(_process())
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return doc_id
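

# A minimal usage sketch. Pool construction is not shown anywhere above, so
# `create_mysql_pool` and `create_vector_pool` are hypothetical stand-ins for
# whatever factories the application provides, and `ChunkerConfig()` assumes
# the config is default-constructible; only `ChunkEmbeddingTask(...)` and
# `deal(...)` come from the class itself.
#
#     async def main():
#         mysql_pool = await create_mysql_pool()    # hypothetical factory
#         vector_pool = await create_vector_pool()  # hypothetical factory
#         task = ChunkEmbeddingTask(mysql_pool, vector_pool, cfg=ChunkerConfig())
#         doc_id = await task.deal({"text": "some long document ..."})
#         print(doc_id)  # returned immediately; chunking runs in the background
#
#     asyncio.run(main())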