chunk_task.py
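"""
Background chunk-and-embed pipeline: split a document into topic-aware
chunks, classify each chunk with an LLM, embed its text, summary, and
generated questions, and write the vectors to Milvus, while tracking
per-step status in MySQL.
"""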

import asyncio
from typing import List

from applications.api import get_basic_embedding
from applications.utils.async_utils import run_tasks_with_asyncio_task_group
from applications.utils.mysql import ContentChunks, Contents
from applications.utils.chunks import TopicAwareChunker, LLMClassifier
from applications.utils.milvus import async_insert_chunk
from applications.config import Chunk, ChunkerConfig, DEFAULT_MODEL

class ChunkEmbeddingTask(TopicAwareChunker):
    """Chunk -> classify -> embed pipeline for a single document.

    Content and chunk rows move INIT -> PROCESSING -> FINISHED (FAILED on
    error); each conditional status update doubles as an optimistic lock.
    """

    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig, doc_id):
        super().__init__(cfg, doc_id)
        self.content_chunk_processor = None
        self.contents_processor = None
        self.mysql_pool = mysql_pool
        self.vector_pool = vector_pool
        self.classifier = LLMClassifier()

    @staticmethod
    async def get_embedding_list(text: str) -> List:
        return await get_basic_embedding(text=text, model=DEFAULT_MODEL, dev=True)

    def init_processer(self):
        self.contents_processor = Contents(self.mysql_pool)
        self.content_chunk_processor = ContentChunks(self.mysql_pool)

    async def process_content(self, doc_id, text) -> List[Chunk]:
        # insert the raw document; bail out if the insert fails
        flag = await self.contents_processor.insert_content(doc_id, text)
        if not flag:
            return []

        raw_chunks = await self.chunk(text)
        if not raw_chunks:
            await self.contents_processor.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []

        await self.contents_processor.update_content_status(
            doc_id=doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        return raw_chunks
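
    # The ori_status/new_status pairs above implement optimistic locking. A
    # plausible shape for update_content_status (an assumption, not the actual
    # applications.utils.mysql implementation) is a single conditional UPDATE:
    #
    #   UPDATE contents SET status = %(new_status)s
    #   WHERE doc_id = %(doc_id)s AND status = %(ori_status)s;
    #
    # where the affected-row count tells the caller whether it won the lock.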

    async def process_each_chunk(self, chunk: Chunk):
        # insert the chunk row
        flag = await self.content_chunk_processor.insert_chunk(chunk)
        if not flag:
            print("Failed to insert chunk")
            return

        # grab the chunk processing lock
        acquire_lock = await self.content_chunk_processor.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print("Failed to acquire the chunk processing lock")
            return

        # classify the chunk with the LLM
        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self.content_chunk_processor.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("Failed to get a response from deepseek")
            return

        # persist the classification result
        update_flag = await self.content_chunk_processor.set_chunk_result(
            chunk=completion,
            ori_status=self.PROCESSING_STATUS,
            new_status=self.FINISHED_STATUS,
        )
        if not update_flag:
            await self.content_chunk_processor.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return

        await self.save_to_milvus(completion)

    async def save_to_milvus(self, chunk: Chunk):
        """Embed a classified chunk and insert it into Milvus.

        :param chunk: a single classified chunk
        """
        # grab the embedding lock
        acquire_lock = await self.content_chunk_processor.update_embedding_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            new_status=self.PROCESSING_STATUS,
            ori_status=self.INIT_STATUS,
        )
        if not acquire_lock:
            print(
                f"Failed to acquire the embedding lock for chunk "
                f"{chunk.doc_id}-{chunk.chunk_id}"
            )
            return
        try:
            data = {
                "doc_id": chunk.doc_id,
                "chunk_id": chunk.chunk_id,
                # three separate vectors: raw text, summary, generated questions
                "vector_text": await self.get_embedding_list(chunk.text),
                "vector_summary": await self.get_embedding_list(chunk.summary),
                "vector_questions": await self.get_embedding_list(
                    ",".join(chunk.questions)
                ),
                "topic": chunk.topic,
                "domain": chunk.domain,
                "task_type": chunk.task_type,
                "summary": chunk.summary,
                "keywords": chunk.keywords,
                "entities": chunk.entities,
                "concepts": chunk.concepts,
                "questions": chunk.questions,
                "topic_purity": chunk.topic_purity,
                "tokens": chunk.tokens,
            }
            await async_insert_chunk(self.vector_pool, data)
            await self.content_chunk_processor.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
        except Exception as e:
            await self.content_chunk_processor.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("Failed to insert into the vector database:", e)
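
    # Note: the three get_embedding_list calls above run sequentially even
    # though they are independent; they could be issued concurrently, e.g.:
    #
    #   vec_text, vec_summary, vec_questions = await asyncio.gather(
    #       self.get_embedding_list(chunk.text),
    #       self.get_embedding_list(chunk.summary),
    #       self.get_embedding_list(",".join(chunk.questions)),
    #   )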

    async def deal(self, data):
        text = (data.get("text") or "").strip()
        if not text:
            return None

        self.init_processer()

        async def _process():
            chunks = await self.process_content(self.doc_id, text)
            if not chunks:
                return

            # dev: process sequentially instead of concurrently
            # for chunk in chunks:
            #     await self.process_each_chunk(chunk)
            await run_tasks_with_asyncio_task_group(
                task_list=chunks,
                handler=self.process_each_chunk,
                description="process chunks of a single document",
                unit="chunk",
                max_concurrency=10,
            )
            await self.contents_processor.update_content_status(
                doc_id=self.doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )

        # fire and forget: keep a reference so the task is not garbage-collected
        # before it finishes; the caller gets the doc_id back immediately
        self._background_task = asyncio.create_task(_process())
        return self.doc_id
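

# ---------------------------------------------------------------------------
# Usage sketch, not part of the task itself. `make_mysql_pool` and
# `make_milvus_pool` are hypothetical placeholders for whatever connection
# factories the surrounding application provides.
#
#   async def main():
#       mysql_pool = await make_mysql_pool()    # hypothetical factory
#       vector_pool = await make_milvus_pool()  # hypothetical factory
#       task = ChunkEmbeddingTask(
#           mysql_pool, vector_pool, cfg=ChunkerConfig(), doc_id="doc-0001"
#       )
#       doc_id = await task.deal({"text": "long document text ..."})
#       # deal() returns immediately; chunking, classification, and embedding
#       # continue in the background task it spawned.
#
#   asyncio.run(main())
# ---------------------------------------------------------------------------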