chunk_task.py

import asyncio
from typing import List

from applications.api import get_basic_embedding
from applications.utils.async_utils import run_tasks_with_asyncio_task_group
from applications.utils.chunks import TopicAwareChunker, LLMClassifier
from applications.utils.milvus import async_insert_chunk
from applications.utils.mysql import ContentChunks, Contents
from applications.config import Chunk, ChunkerConfig, DEFAULT_MODEL, ELASTIC_SEARCH_INDEX


class ChunkEmbeddingTask(TopicAwareChunker):
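    """
    End-to-end task that splits a document into topic-aware chunks, classifies
    each chunk with an LLM, stores embeddings in Milvus, and indexes the result
    in Elasticsearch. Progress is tracked in MySQL through status transitions
    (INIT -> PROCESSING -> FINISHED/FAILED) that double as distributed locks.
    """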
    def __init__(self, cfg: ChunkerConfig, doc_id, resource):
        super().__init__(cfg, doc_id)
        self.chunk_manager = None
        self.content_manager = None
        self.mysql_client = resource.mysql_client
        self.milvus_client = resource.milvus_client
        self.es_client = resource.es_client
        self.classifier = LLMClassifier()

    @staticmethod
    async def get_embedding_list(text: str) -> List:
        return await get_basic_embedding(text=text, model=DEFAULT_MODEL)

    def init_processer(self):
        self.content_manager = Contents(self.mysql_client)
        self.chunk_manager = ContentChunks(self.mysql_client)

    async def _chunk_each_content(
        self, doc_id: str, text: str, text_type: int, title: str, dataset_id: int
    ) -> List[Chunk]:
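        """
        Persist the raw content, then split it into topic-aware chunks and move
        the content row from INIT to PROCESSING (or FAILED when chunking yields
        nothing).

        :return: the raw chunks; an empty list if the insert or chunking failed
        """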
        flag = await self.content_manager.insert_content(
            doc_id, text, text_type, title, dataset_id
        )
        if not flag:
            return []

        raw_chunks = await self.chunk(text, text_type, dataset_id)
        if not raw_chunks:
            await self.content_manager.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []

        await self.content_manager.update_content_status(
            doc_id=doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        return raw_chunks

    async def insert_into_es(self, milvus_id, chunk: Chunk) -> int:
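        """
        Index a single classified chunk into Elasticsearch, keyed by its
        Milvus id so the two stores can be joined later.

        :return: the success count reported by the bulk insert
        """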
        docs = [
            {
                "_index": ELASTIC_SEARCH_INDEX,
                "_id": milvus_id,
                "_source": {
                    "milvus_id": milvus_id,
                    "doc_id": chunk.doc_id,
                    "dataset_id": chunk.dataset_id,
                    "chunk_id": chunk.chunk_id,
                    "topic": chunk.topic,
                    "domain": chunk.domain,
                    "task_type": chunk.task_type,
                    "text_type": chunk.text_type,
                    "keywords": chunk.keywords,
                    "concepts": chunk.concepts,
                    "entities": chunk.entities,
                    "status": chunk.status,
                },
            }
        ]
        resp = await self.es_client.bulk_insert(docs)
        return resp["success"]

    async def save_each_chunk(self, chunk: Chunk):
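        """
        Run the full per-chunk pipeline: persist the chunk, classify it with
        the LLM, embed it into Milvus, then index it into Elasticsearch. Each
        stage advances a status column via compare-and-set updates, so a failed
        stage marks the chunk FAILED and stops the pipeline early.
        """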
        # persist the raw chunk row
        flag = await self.chunk_manager.insert_chunk(chunk)
        if not flag:
            print("failed to insert chunk")
            return

        acquire_lock = await self.chunk_manager.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print("failed to acquire the chunk-processing lock")
            return

        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("failed to get classification from DeepSeek")
            return

        update_flag = await self.chunk_manager.set_chunk_result(
            chunk=completion,
            ori_status=self.PROCESSING_STATUS,
            new_status=self.FINISHED_STATUS,
        )
        if not update_flag:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return

        milvus_id = await self.save_to_milvus(completion)
        if not milvus_id:
            return

        # index into Elasticsearch: acquire the ES lock first
        acquire_es_lock = await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_es_lock:
            print(f"failed to acquire the ES lock: {chunk.doc_id}--{chunk.chunk_id}")
            return

        insert_rows = await self.insert_into_es(milvus_id, completion)
        final_status = self.FINISHED_STATUS if insert_rows else self.FAILED_STATUS
        await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.PROCESSING_STATUS,
            new_status=final_status,
        )

    async def save_to_milvus(self, chunk: Chunk):
        """
        Embed a classified chunk and insert the vectors into Milvus.

        :param chunk: a single classified chunk
        :return: the Milvus primary key on success, otherwise None
        """
        # acquire the embedding lock via a compare-and-set status update
        acquire_lock = await self.chunk_manager.update_embedding_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print(
                f"failed to acquire the embedding lock for chunk "
                f"{chunk.doc_id}-{chunk.chunk_id}"
            )
            return None

        try:
            data = {
                "doc_id": chunk.doc_id,
                "chunk_id": chunk.chunk_id,
                "vector_text": await self.get_embedding_list(chunk.text),
                "vector_summary": await self.get_embedding_list(chunk.summary),
                "vector_questions": await self.get_embedding_list(
                    ",".join(chunk.questions)
                ),
            }
            resp = await async_insert_chunk(self.milvus_client, data)
            if not resp:
                await self.chunk_manager.update_embedding_status(
                    doc_id=chunk.doc_id,
                    chunk_id=chunk.chunk_id,
                    ori_status=self.PROCESSING_STATUS,
                    new_status=self.FAILED_STATUS,
                )
                return None

            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
            return resp[0]
        except Exception as e:
            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("failed to insert into the vector database:", e)
            return None

    async def deal(self, data):
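        """
        Entry point: validate the payload, then schedule the chunk/classify/
        embed/index pipeline as a background task and return the doc_id
        immediately, without waiting for processing to finish.
        """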
        text = data.get("text", "")
        title = data.get("title", "")
        text, title = text.strip(), title.strip()
        text_type = data.get("text_type", 1)
        dataset_id = data.get("dataset_id", 0)  # the default dataset (knowledge base) id is 0
        if not text:
            return None
        self.init_processer()

        async def _process():
            chunks = await self._chunk_each_content(
                self.doc_id, text, text_type, title, dataset_id
            )
            if not chunks:
                return
            # dev: process the chunks serially instead
            # for chunk in chunks:
            #     await self.save_each_chunk(chunk)
            await run_tasks_with_asyncio_task_group(
                task_list=chunks,
                handler=self.save_each_chunk,
                description="process the chunks of a single document",
                unit="chunk",
                max_concurrency=10,
            )
            await self.content_manager.update_content_status(
                doc_id=self.doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )

        # fire-and-forget: run the pipeline in the background
        asyncio.create_task(_process())
        return self.doc_id
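

# Minimal usage sketch (hypothetical wiring: `resource` must expose
# mysql_client, milvus_client, and es_client; the ChunkerConfig values and
# payload fields shown here are assumptions, since the real entry point
# lives outside this module):
#
#   task = ChunkEmbeddingTask(cfg=ChunkerConfig(), doc_id=doc_id, resource=resource)
#   returned_doc_id = await task.deal(
#       {"text": "...", "title": "...", "text_type": 1, "dataset_id": 0}
#   )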