chunk_task.py

import asyncio
from typing import List

from applications.api import get_basic_embedding
from applications.utils.async_utils import run_tasks_with_asyncio_task_group
from applications.utils.chunks import LLMClassifier, TopicAwarePackerV2
from applications.utils.milvus import async_insert_chunk
from applications.utils.mysql import ContentChunks, Contents
from applications.config import Chunk, DEFAULT_MODEL, ELASTIC_SEARCH_INDEX


class ChunkEmbeddingTask(TopicAwarePackerV2):
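    """Pipeline task that chunks a document, classifies each chunk with an
    LLM, writes embeddings to Milvus, and indexes chunk metadata into
    Elasticsearch, tracking per-stage status in MySQL along the way."""
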
    def __init__(self, doc_id, resource):
        super().__init__(doc_id)
        self.chunk_manager = None
        self.content_manager = None
        self.mysql_client = resource.mysql_client
        self.milvus_client = resource.milvus_client
        self.es_client = resource.es_client
        self.classifier = LLMClassifier()

    @staticmethod
    async def get_embedding_list(text: str) -> List:
        return await get_basic_embedding(text=text, model=DEFAULT_MODEL)

    def init_processor(self):
        self.content_manager = Contents(self.mysql_client)
        self.chunk_manager = ContentChunks(self.mysql_client)

    async def _chunk_each_content(
        self,
        doc_id: str,
        text: str,
        text_type: int,
        title: str,
        dataset_id: int,
        re_chunk: bool,
    ) -> List[Chunk]:
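        """Insert the content row (or update it when re-chunking), split the
        text into chunks, and move the content status from INIT to PROCESSING.
        Returns an empty list if the insert or the chunking fails."""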
        if re_chunk:
            await self.content_manager.update_content_info(
                doc_id=doc_id,
                text=text,
                text_type=text_type,
                title=title,
                dataset_id=dataset_id,
            )
            flag = True
        else:
            flag = await self.content_manager.insert_content(
                doc_id, text, text_type, title, dataset_id
            )
        if not flag:
            return []

        raw_chunks = await self.chunk(text, text_type, dataset_id)
        if not raw_chunks:
            await self.content_manager.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []

        await self.content_manager.update_content_status(
            doc_id=doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        return raw_chunks

    async def insert_into_es(self, milvus_id, chunk: Chunk) -> int:
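        """Index the chunk's classification metadata into Elasticsearch under
        its Milvus id and return the bulk insert's success count."""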
        docs = [
            {
                "_index": ELASTIC_SEARCH_INDEX,
                "_id": milvus_id,
                "_source": {
                    "milvus_id": milvus_id,
                    "doc_id": chunk.doc_id,
                    "dataset_id": chunk.dataset_id,
                    "chunk_id": chunk.chunk_id,
                    "topic": chunk.topic,
                    "domain": chunk.domain,
                    "task_type": chunk.task_type,
                    "text_type": chunk.text_type,
                    "keywords": chunk.keywords,
                    "concepts": chunk.concepts,
                    "entities": chunk.entities,
                    "status": chunk.status,
                },
            }
        ]
        resp = await self.es_client.bulk_insert(docs)
        return resp["success"]

    async def save_each_chunk(self, chunk: Chunk):
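        """Run the full per-chunk pipeline: persist the chunk, classify it
        with the LLM, store its embeddings in Milvus, then index the result
        into Elasticsearch, advancing (or failing) the status at each stage."""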
        # insert the chunk row
        flag = await self.chunk_manager.insert_chunk(chunk)
        if not flag:
            print("failed to insert chunk")
            return

        acquire_lock = await self.chunk_manager.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print("failed to acquire the chunk processing lock")
            return

        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("failed to get a classification from DeepSeek")
            return

        update_flag = await self.chunk_manager.set_chunk_result(
            chunk=completion,
            ori_status=self.PROCESSING_STATUS,
            new_status=self.FINISHED_STATUS,
        )
        if not update_flag:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return

        milvus_id = await self.save_to_milvus(completion)
        if not milvus_id:
            return

        # store the result in ES; acquire the ES status lock first
        acquire_es_lock = await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_es_lock:
            print(f"failed to acquire the ES lock: {chunk.doc_id}--{chunk.chunk_id}")
            return

        insert_rows = await self.insert_into_es(milvus_id, completion)
        final_status = self.FINISHED_STATUS if insert_rows else self.FAILED_STATUS
        await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.PROCESSING_STATUS,
            new_status=final_status,
        )

    async def save_to_milvus(self, chunk: Chunk):
        """Embed the chunk's text, summary, and questions, insert the vectors
        into Milvus, and return the resulting Milvus id (None on failure)."""
        # acquire the embedding status lock
        acquire_lock = await self.chunk_manager.update_embedding_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            new_status=self.PROCESSING_STATUS,
            ori_status=self.INIT_STATUS,
        )
        if not acquire_lock:
            print(f"failed to acquire the embedding lock for chunk {chunk.doc_id}-{chunk.chunk_id}")
            return None
        try:
            data = {
                "doc_id": chunk.doc_id,
                "chunk_id": chunk.chunk_id,
                "vector_text": await self.get_embedding_list(chunk.text),
                "vector_summary": await self.get_embedding_list(chunk.summary),
                "vector_questions": await self.get_embedding_list(
                    ",".join(chunk.questions)
                ),
            }
            resp = await async_insert_chunk(self.milvus_client, data)
            if not resp:
                await self.chunk_manager.update_embedding_status(
                    doc_id=chunk.doc_id,
                    chunk_id=chunk.chunk_id,
                    ori_status=self.PROCESSING_STATUS,
                    new_status=self.FAILED_STATUS,
                )
                return None
            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
            milvus_id = resp[0]
            return milvus_id
        except Exception as e:
            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("failed to insert into the vector database:", e)
            return None

    async def deal(self, data):
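        """Entry point: validate the payload, kick off chunking and per-chunk
        processing as a background task, and return the doc_id immediately."""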
        text = data.get("text", "")
        title = data.get("title", "")
        text, title = text.strip(), title.strip()
        text_type = data.get("text_type", 1)
        dataset_id = data.get("dataset_id", 0)  # the default knowledge-base (dataset) id is 0
        re_chunk = data.get("re_chunk", False)
        if not text:
            return None

        self.init_processor()

        async def _process():
            chunks = await self._chunk_each_content(
                self.doc_id, text, text_type, title, dataset_id, re_chunk
            )
            if not chunks:
                return
            # dev: process chunks sequentially for easier debugging
            # for chunk in chunks:
            #     await self.save_each_chunk(chunk)
            await run_tasks_with_asyncio_task_group(
                task_list=chunks,
                handler=self.save_each_chunk,
                description="process chunks of a single document",
                unit="chunk",
                max_concurrency=10,
            )
            await self.content_manager.update_content_status(
                doc_id=self.doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
        # keep a reference so the fire-and-forget task is not garbage-collected
        self._process_task = asyncio.create_task(_process())
        return self.doc_id
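

# A minimal usage sketch (not part of the original file): shows how
# ChunkEmbeddingTask might be driven end to end. The `Resource` container and
# the payload shape are assumptions inferred from __init__ and deal(); swap in
# the real MySQL/Milvus/ES clients and doc_id source before running.
#
# from dataclasses import dataclass
#
# @dataclass
# class Resource:  # hypothetical holder for the three clients read in __init__
#     mysql_client: object
#     milvus_client: object
#     es_client: object
#
# async def main():
#     resource = Resource(mysql_client=..., milvus_client=..., es_client=...)
#     task = ChunkEmbeddingTask(doc_id="doc-001", resource=resource)
#     doc_id = await task.deal(
#         {
#             "text": "raw document text ...",
#             "title": "example title",
#             "text_type": 1,
#             "dataset_id": 0,
#             "re_chunk": False,
#         }
#     )
#     print(doc_id)  # deal() returns immediately; chunking runs in the background
#
# asyncio.run(main())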