# chunk_task.py
  1. import asyncio
  2. from typing import List
  3. from applications.api import get_basic_embedding
  4. from applications.utils.async_utils import run_tasks_with_asyncio_task_group
  5. from applications.utils.chunks import LLMClassifier, TopicAwarePackerV2
  6. from applications.utils.milvus import async_insert_chunk
  7. from applications.utils.mysql import ContentChunks, Contents
  8. from applications.config import Chunk, DEFAULT_MODEL
  9. from applications.config import ELASTIC_SEARCH_INDEX
  10. class ChunkEmbeddingTask(TopicAwarePackerV2):
  11. def __init__(self, doc_id, resource):
  12. super().__init__(doc_id)
  13. self.chunk_manager = None
  14. self.content_manager = None
  15. self.mysql_client = resource.mysql_client
  16. self.milvus_client = resource.milvus_client
  17. self.es_client = resource.es_client
  18. self.classifier = LLMClassifier()
  19. @staticmethod
  20. async def get_embedding_list(text: str) -> List:
  21. return await get_basic_embedding(text=text, model=DEFAULT_MODEL)
  22. def init_processer(self):
  23. self.content_manager = Contents(self.mysql_client)
  24. self.chunk_manager = ContentChunks(self.mysql_client)
  25. async def _chunk_each_content(
  26. self,
  27. doc_id: str,
  28. text: str,
  29. text_type: int,
  30. title: str,
  31. dataset_id: int,
  32. re_chunk: bool,
  33. ) -> List[Chunk]:
  34. if re_chunk:
  35. flag = await self.content_manager.update_content_info(
  36. doc_id=doc_id,
  37. text=text,
  38. text_type=text_type,
  39. title=title,
  40. dataset_id=dataset_id,
  41. )
  42. else:
  43. flag = await self.content_manager.insert_content(
  44. doc_id, text, text_type, title, dataset_id
  45. )
  46. if not flag:
  47. return []
  48. else:
  49. raw_chunks = await self.chunk(text, text_type, dataset_id)
  50. if not raw_chunks:
  51. await self.content_manager.update_content_status(
  52. doc_id=doc_id,
  53. ori_status=self.INIT_STATUS,
  54. new_status=self.FAILED_STATUS,
  55. )
  56. return []
  57. await self.content_manager.update_content_status(
  58. doc_id=doc_id,
  59. ori_status=self.INIT_STATUS,
  60. new_status=self.PROCESSING_STATUS,
  61. )
  62. return raw_chunks
  63. async def insert_into_es(self, milvus_id, chunk: Chunk) -> int:
  64. docs = [
  65. {
  66. "_index": ELASTIC_SEARCH_INDEX,
  67. "_id": milvus_id,
  68. "_source": {
  69. "milvus_id": milvus_id,
  70. "doc_id": chunk.doc_id,
  71. "dataset_id": chunk.dataset_id,
  72. "chunk_id": chunk.chunk_id,
  73. "topic": chunk.topic,
  74. "domain": chunk.domain,
  75. "task_type": chunk.task_type,
  76. "text_type": chunk.text_type,
  77. "keywords": chunk.keywords,
  78. "concepts": chunk.concepts,
  79. "entities": chunk.entities,
  80. "status": chunk.status,
  81. },
  82. }
  83. ]
  84. resp = await self.es_client.bulk_insert(docs)
  85. return resp["success"]
  86. async def save_each_chunk(self, chunk: Chunk):
  87. # insert
  88. flag = await self.chunk_manager.insert_chunk(chunk)
  89. if not flag:
  90. print("插入文本失败")
  91. return
  92. acquire_lock = await self.chunk_manager.update_chunk_status(
  93. doc_id=chunk.doc_id,
  94. chunk_id=chunk.chunk_id,
  95. ori_status=self.INIT_STATUS,
  96. new_status=self.PROCESSING_STATUS,
  97. )
  98. if not acquire_lock:
  99. print("抢占文本分块锁失败")
  100. return
  101. completion = await self.classifier.classify_chunk(chunk)
  102. if not completion:
  103. await self.chunk_manager.update_chunk_status(
  104. doc_id=chunk.doc_id,
  105. chunk_id=chunk.chunk_id,
  106. ori_status=self.PROCESSING_STATUS,
  107. new_status=self.FAILED_STATUS,
  108. )
  109. print("从deepseek获取信息失败")
  110. return
  111. update_flag = await self.chunk_manager.set_chunk_result(
  112. chunk=completion,
  113. ori_status=self.PROCESSING_STATUS,
  114. new_status=self.FINISHED_STATUS,
  115. )
  116. if not update_flag:
  117. await self.chunk_manager.update_chunk_status(
  118. doc_id=chunk.doc_id,
  119. chunk_id=chunk.chunk_id,
  120. ori_status=self.PROCESSING_STATUS,
  121. new_status=self.FAILED_STATUS,
  122. )
  123. return
  124. milvus_id = await self.save_to_milvus(completion)
  125. if not milvus_id:
  126. return
  127. # 存储到 es 中
  128. # acquire_lock
  129. acquire_es_lock = await self.chunk_manager.update_es_status(
  130. doc_id=chunk.doc_id,
  131. chunk_id=chunk.chunk_id,
  132. ori_status=self.INIT_STATUS,
  133. new_status=self.PROCESSING_STATUS,
  134. )
  135. if not acquire_es_lock:
  136. print(f"获取 es Lock Fail: {chunk.doc_id}--{chunk.chunk_id}")
  137. return
  138. insert_rows = await self.insert_into_es(milvus_id, completion)
  139. final_status = self.FINISHED_STATUS if insert_rows else self.FAILED_STATUS
  140. await self.chunk_manager.update_es_status(
  141. doc_id=chunk.doc_id,
  142. chunk_id=chunk.chunk_id,
  143. ori_status=self.PROCESSING_STATUS,
  144. new_status=final_status,
  145. )
  146. async def save_to_milvus(self, chunk: Chunk):
  147. """
  148. :param chunk: each single chunk
  149. :return:
  150. """
  151. # 抢锁
  152. acquire_lock = await self.chunk_manager.update_embedding_status(
  153. doc_id=chunk.doc_id,
  154. chunk_id=chunk.chunk_id,
  155. new_status=self.PROCESSING_STATUS,
  156. ori_status=self.INIT_STATUS,
  157. )
  158. if not acquire_lock:
  159. print(f"抢占-{chunk.doc_id}-{chunk.chunk_id}分块-embedding处理锁失败")
  160. return None
  161. try:
  162. data = {
  163. "doc_id": chunk.doc_id,
  164. "chunk_id": chunk.chunk_id,
  165. "vector_text": await self.get_embedding_list(chunk.text),
  166. "vector_summary": await self.get_embedding_list(chunk.summary),
  167. "vector_questions": await self.get_embedding_list(
  168. ",".join(chunk.questions)
  169. ),
  170. }
  171. resp = await async_insert_chunk(self.milvus_client, data)
  172. if not resp:
  173. await self.chunk_manager.update_embedding_status(
  174. doc_id=chunk.doc_id,
  175. chunk_id=chunk.chunk_id,
  176. ori_status=self.PROCESSING_STATUS,
  177. new_status=self.FAILED_STATUS,
  178. )
  179. return None
  180. await self.chunk_manager.update_embedding_status(
  181. doc_id=chunk.doc_id,
  182. chunk_id=chunk.chunk_id,
  183. ori_status=self.PROCESSING_STATUS,
  184. new_status=self.FINISHED_STATUS,
  185. )
  186. milvus_id = resp[0]
  187. return milvus_id
  188. except Exception as e:
  189. await self.chunk_manager.update_embedding_status(
  190. doc_id=chunk.doc_id,
  191. chunk_id=chunk.chunk_id,
  192. ori_status=self.PROCESSING_STATUS,
  193. new_status=self.FAILED_STATUS,
  194. )
  195. print(f"存入向量数据库失败", e)
  196. return None
  197. async def deal(self, data):
  198. text = data.get("text", "")
  199. title = data.get("title", "")
  200. text, title = text.strip(), title.strip()
  201. text_type = data.get("text_type", 1)
  202. dataset_id = data.get("dataset_id", 0) # 默认知识库 id 为 0
  203. re_chunk = data.get("re_chunk", False)
  204. if not text:
  205. return None
  206. self.init_processer()
  207. async def _process():
  208. chunks = await self._chunk_each_content(
  209. self.doc_id, text, text_type, title, dataset_id, re_chunk
  210. )
  211. if not chunks:
  212. return
  213. # # dev
  214. # for chunk in chunks:
  215. # await self.save_each_chunk(chunk)
  216. await run_tasks_with_asyncio_task_group(
  217. task_list=chunks,
  218. handler=self.save_each_chunk,
  219. description="处理单篇文章分块",
  220. unit="chunk",
  221. max_concurrency=10,
  222. )
  223. await self.content_manager.update_content_status(
  224. doc_id=self.doc_id,
  225. ori_status=self.PROCESSING_STATUS,
  226. new_status=self.FINISHED_STATUS,
  227. )
  228. asyncio.create_task(_process())
  229. return self.doc_id