# chunk_task.py
  1. import asyncio
  2. from typing import List
  3. from applications.api import get_basic_embedding
  4. from applications.utils.async_utils import run_tasks_with_asyncio_task_group
  5. from applications.utils.chunks import LLMClassifier, TopicAwarePackerV2
  6. from applications.utils.milvus import async_insert_chunk
  7. from applications.utils.mysql import ContentChunks, Contents
  8. from applications.utils.nlp import num_tokens
  9. from applications.config import Chunk, DEFAULT_MODEL
  10. from applications.config import ELASTIC_SEARCH_INDEX
  11. class ChunkEmbeddingTask(TopicAwarePackerV2):
  12. def __init__(self, doc_id, resource):
  13. super().__init__(doc_id)
  14. self.chunk_manager = None
  15. self.content_manager = None
  16. self.mysql_client = resource.mysql_client
  17. self.milvus_client = resource.milvus_client
  18. self.es_client = resource.es_client
  19. self.classifier = LLMClassifier()
  20. @staticmethod
  21. async def get_embedding_list(text: str) -> List:
  22. return await get_basic_embedding(text=text, model=DEFAULT_MODEL)
  23. def init_processer(self):
  24. self.content_manager = Contents(self.mysql_client)
  25. self.chunk_manager = ContentChunks(self.mysql_client)
  26. async def _chunk_each_content(
  27. self,
  28. doc_id: str,
  29. data: dict
  30. ) -> List[Chunk]:
  31. title, text = data.get("title", "").strip(), data["text"].strip()
  32. text_type = data.get("text_type", 1)
  33. dataset_id = data.get("dataset_id", 0) # 默认知识库 id 为 0
  34. re_chunk = data.get("re_chunk", False)
  35. dont_chunk = data.get("dont_chunk", False)
  36. if re_chunk:
  37. await self.content_manager.update_content_info(
  38. doc_id=doc_id,
  39. text=text,
  40. text_type=text_type,
  41. title=title,
  42. dataset_id=dataset_id,
  43. )
  44. flag = True
  45. else:
  46. flag = await self.content_manager.insert_content(
  47. doc_id, text, text_type, title, dataset_id
  48. )
  49. if not flag:
  50. return []
  51. else:
  52. raw_chunks = await self.chunk(text, text_type, dataset_id, dont_chunk)
  53. if not raw_chunks:
  54. await self.content_manager.update_content_status(
  55. doc_id=doc_id,
  56. ori_status=self.INIT_STATUS,
  57. new_status=self.FAILED_STATUS,
  58. )
  59. return []
  60. await self.content_manager.update_content_status(
  61. doc_id=doc_id,
  62. ori_status=self.INIT_STATUS,
  63. new_status=self.PROCESSING_STATUS,
  64. )
  65. return raw_chunks
  66. async def insert_into_es(self, milvus_id, chunk: Chunk) -> int:
  67. docs = [
  68. {
  69. "_index": ELASTIC_SEARCH_INDEX,
  70. "_id": milvus_id,
  71. "_source": {
  72. "milvus_id": milvus_id,
  73. "doc_id": chunk.doc_id,
  74. "dataset_id": chunk.dataset_id,
  75. "chunk_id": chunk.chunk_id,
  76. "topic": chunk.topic,
  77. "domain": chunk.domain,
  78. "task_type": chunk.task_type,
  79. "text_type": chunk.text_type,
  80. "keywords": chunk.keywords,
  81. "concepts": chunk.concepts,
  82. "entities": chunk.entities,
  83. "status": chunk.status,
  84. },
  85. }
  86. ]
  87. resp = await self.es_client.bulk_insert(docs)
  88. return resp["success"]
  89. async def save_each_chunk(self, chunk: Chunk):
  90. # insert
  91. flag = await self.chunk_manager.insert_chunk(chunk)
  92. if not flag:
  93. print("插入文本失败")
  94. return
  95. acquire_lock = await self.chunk_manager.update_chunk_status(
  96. doc_id=chunk.doc_id,
  97. chunk_id=chunk.chunk_id,
  98. ori_status=self.INIT_STATUS,
  99. new_status=self.PROCESSING_STATUS,
  100. )
  101. if not acquire_lock:
  102. print("抢占文本分块锁失败")
  103. return
  104. completion = await self.classifier.classify_chunk(chunk)
  105. if not completion:
  106. await self.chunk_manager.update_chunk_status(
  107. doc_id=chunk.doc_id,
  108. chunk_id=chunk.chunk_id,
  109. ori_status=self.PROCESSING_STATUS,
  110. new_status=self.FAILED_STATUS,
  111. )
  112. print("从deepseek获取信息失败")
  113. return
  114. update_flag = await self.chunk_manager.set_chunk_result(
  115. chunk=completion,
  116. ori_status=self.PROCESSING_STATUS,
  117. new_status=self.FINISHED_STATUS,
  118. )
  119. if not update_flag:
  120. await self.chunk_manager.update_chunk_status(
  121. doc_id=chunk.doc_id,
  122. chunk_id=chunk.chunk_id,
  123. ori_status=self.PROCESSING_STATUS,
  124. new_status=self.FAILED_STATUS,
  125. )
  126. return
  127. milvus_id = await self.save_to_milvus(completion)
  128. if not milvus_id:
  129. return
  130. # 存储到 es 中
  131. # acquire_lock
  132. acquire_es_lock = await self.chunk_manager.update_es_status(
  133. doc_id=chunk.doc_id,
  134. chunk_id=chunk.chunk_id,
  135. ori_status=self.INIT_STATUS,
  136. new_status=self.PROCESSING_STATUS,
  137. )
  138. if not acquire_es_lock:
  139. print(f"获取 es Lock Fail: {chunk.doc_id}--{chunk.chunk_id}")
  140. return
  141. insert_rows = await self.insert_into_es(milvus_id, completion)
  142. final_status = self.FINISHED_STATUS if insert_rows else self.FAILED_STATUS
  143. await self.chunk_manager.update_es_status(
  144. doc_id=chunk.doc_id,
  145. chunk_id=chunk.chunk_id,
  146. ori_status=self.PROCESSING_STATUS,
  147. new_status=final_status,
  148. )
  149. async def save_to_milvus(self, chunk: Chunk):
  150. """
  151. :param chunk: each single chunk
  152. :return:
  153. """
  154. # 抢锁
  155. acquire_lock = await self.chunk_manager.update_embedding_status(
  156. doc_id=chunk.doc_id,
  157. chunk_id=chunk.chunk_id,
  158. new_status=self.PROCESSING_STATUS,
  159. ori_status=self.INIT_STATUS,
  160. )
  161. if not acquire_lock:
  162. print(f"抢占-{chunk.doc_id}-{chunk.chunk_id}分块-embedding处理锁失败")
  163. return None
  164. try:
  165. data = {
  166. "doc_id": chunk.doc_id,
  167. "chunk_id": chunk.chunk_id,
  168. "vector_text": await self.get_embedding_list(chunk.text),
  169. "vector_summary": await self.get_embedding_list(chunk.summary),
  170. "vector_questions": await self.get_embedding_list(
  171. ",".join(chunk.questions)
  172. ),
  173. }
  174. resp = await async_insert_chunk(self.milvus_client, data)
  175. if not resp:
  176. await self.chunk_manager.update_embedding_status(
  177. doc_id=chunk.doc_id,
  178. chunk_id=chunk.chunk_id,
  179. ori_status=self.PROCESSING_STATUS,
  180. new_status=self.FAILED_STATUS,
  181. )
  182. return None
  183. await self.chunk_manager.update_embedding_status(
  184. doc_id=chunk.doc_id,
  185. chunk_id=chunk.chunk_id,
  186. ori_status=self.PROCESSING_STATUS,
  187. new_status=self.FINISHED_STATUS,
  188. )
  189. milvus_id = resp[0]
  190. return milvus_id
  191. except Exception as e:
  192. await self.chunk_manager.update_embedding_status(
  193. doc_id=chunk.doc_id,
  194. chunk_id=chunk.chunk_id,
  195. ori_status=self.PROCESSING_STATUS,
  196. new_status=self.FAILED_STATUS,
  197. )
  198. print(f"存入向量数据库失败", e)
  199. return None
  200. async def deal(self, data):
  201. text = data.get("text", "")
  202. dont_chunk = data.get("dont_chunk", False)
  203. # 如果无需分块,判断text 长度
  204. if dont_chunk and num_tokens(text) >= self.max_tokens:
  205. return {
  206. "error": "文档超多模型支持的最大吞吐量"
  207. }
  208. self.init_processer()
  209. async def _process():
  210. chunks = await self._chunk_each_content(self.doc_id, data)
  211. if not chunks:
  212. return
  213. # # dev
  214. # for chunk in chunks:
  215. # await self.save_each_chunk(chunk)
  216. await run_tasks_with_asyncio_task_group(
  217. task_list=chunks,
  218. handler=self.save_each_chunk,
  219. description="处理单篇文章分块",
  220. unit="chunk",
  221. max_concurrency=20,
  222. )
  223. await self.content_manager.update_content_status(
  224. doc_id=self.doc_id,
  225. ori_status=self.PROCESSING_STATUS,
  226. new_status=self.FINISHED_STATUS,
  227. )
  228. asyncio.create_task(_process())
  229. return self.doc_id