chunk_task.py

import asyncio
import uuid
from typing import List

from applications.api import get_basic_embedding
from applications.utils.async_utils import run_tasks_with_asyncio_task_group
from applications.utils.mysql import ContentChunks, Contents
from applications.utils.chunks import TopicAwareChunker, LLMClassifier
from applications.utils.milvus import async_insert_chunk
from applications.config import Chunk, ChunkerConfig, DEFAULT_MODEL
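

# ChunkEmbeddingTask drives the per-document pipeline implemented below:
# persist the raw content, split it into topic-aware chunks, classify each
# chunk with the LLM, embed the results, and write them to Milvus. Status
# transitions in MySQL (INIT -> PROCESSING -> FINISHED/FAILED) double as
# lightweight locks so the same chunk is not processed twice.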
class ChunkEmbeddingTask(TopicAwareChunker):
    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig, doc_id):
        super().__init__(cfg, doc_id)
        self.content_chunk_processor = None
        self.contents_processor = None
        self.mysql_pool = mysql_pool
        self.vector_pool = vector_pool
        self.classifier = LLMClassifier()

    @staticmethod
    async def get_embedding_list(text: str) -> List:
        return await get_basic_embedding(text=text, model=DEFAULT_MODEL)
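
    # Bind the MySQL data-access helpers; called from deal() so each task gets
    # fresh processor instances on top of the shared connection pool.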
    def init_processer(self):
        self.contents_processor = Contents(self.mysql_pool)
        self.content_chunk_processor = ContentChunks(self.mysql_pool)
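
    # Persist the raw document, split it into chunks, and move the content row
    # from INIT to PROCESSING (or to FAILED when chunking yields nothing).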
    async def process_content(
        self, doc_id: str, text: str, text_type: int
    ) -> List[Chunk]:
        flag = await self.contents_processor.insert_content(doc_id, text, text_type)
        if not flag:
            return []
        else:
            raw_chunks = await self.chunk(text, text_type)
            if not raw_chunks:
                await self.contents_processor.update_content_status(
                    doc_id=doc_id,
                    ori_status=self.INIT_STATUS,
                    new_status=self.FAILED_STATUS,
                )
                return []

            await self.contents_processor.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.PROCESSING_STATUS,
            )
            return raw_chunks
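
    # Handle one chunk end to end: insert the row, take the processing lock via
    # an INIT -> PROCESSING status update, classify it with the LLM, store the
    # result, then hand the enriched chunk to save_to_milvus(). Classification
    # or persistence failures mark the chunk FAILED.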
    async def process_each_chunk(self, chunk: Chunk):
        # insert
        flag = await self.content_chunk_processor.insert_chunk(chunk)
        if not flag:
            print("failed to insert chunk")
            return

        acquire_lock = await self.content_chunk_processor.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print("failed to acquire the chunk processing lock")
            return

        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self.content_chunk_processor.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("failed to get a response from deepseek")
            return

        update_flag = await self.content_chunk_processor.set_chunk_result(
            chunk=completion,
            ori_status=self.PROCESSING_STATUS,
            new_status=self.FINISHED_STATUS,
        )
        if not update_flag:
            await self.content_chunk_processor.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return

        await self.save_to_milvus(completion)
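
    # Embed the classified chunk (text, summary, and joined questions), write
    # one record to Milvus, and track the outcome in a separate embedding
    # status, independent of the classification status.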
    async def save_to_milvus(self, chunk: Chunk):
        """
        :param chunk: each single chunk
        :return:
        """
        # acquire the lock
        acquire_lock = await self.content_chunk_processor.update_embedding_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            new_status=self.PROCESSING_STATUS,
            ori_status=self.INIT_STATUS,
        )
        if not acquire_lock:
            print(
                f"failed to acquire the embedding lock for chunk {chunk.doc_id}-{chunk.chunk_id}"
            )
            return
        try:
            data = {
                "doc_id": chunk.doc_id,
                "chunk_id": chunk.chunk_id,
                "vector_text": await self.get_embedding_list(chunk.text),
                "vector_summary": await self.get_embedding_list(chunk.summary),
                "vector_questions": await self.get_embedding_list(
                    ",".join(chunk.questions)
                ),
                "topic": chunk.topic,
                "domain": chunk.domain,
                "task_type": chunk.task_type,
                "summary": chunk.summary,
                "keywords": chunk.keywords,
                "entities": chunk.entities,
                "concepts": chunk.concepts,
                "questions": chunk.questions,
                "topic_purity": chunk.topic_purity,
                "tokens": chunk.tokens,
            }
            await async_insert_chunk(self.vector_pool, data)
            await self.content_chunk_processor.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
        except Exception as e:
            await self.content_chunk_processor.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("failed to insert into the vector database:", e)
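
    # Entry point: validate the payload, bind the processors, and schedule the
    # whole pipeline as a background task so the caller gets the doc_id back
    # immediately.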
    async def deal(self, data):
        text = data.get("text", "")
        text = text.strip()
        text_type = data.get("text_type", 1)
        if not text:
            return None

        self.init_processer()

        async def _process():
            chunks = await self.process_content(self.doc_id, text, text_type)
            if not chunks:
                return

            # # dev
            # for chunk in chunks:
            #     await self.process_each_chunk(chunk)
            await run_tasks_with_asyncio_task_group(
                task_list=chunks,
                handler=self.process_each_chunk,
                description="process chunks of a single article",
                unit="chunk",
                max_concurrency=10,
            )
            await self.contents_processor.update_content_status(
                doc_id=self.doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
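
        # Fire-and-forget: run the pipeline on the running event loop and
        # return right away; note that no reference to the created task is kept.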
        asyncio.create_task(_process())
        return self.doc_id