# chunk_task.py

import asyncio
from typing import List

from applications.api import get_basic_embedding
from applications.utils.async_utils import run_tasks_with_asyncio_task_group
from applications.utils.chunks import LLMClassifier, TopicAwarePackerV2
from applications.utils.milvus import async_insert_chunk
from applications.utils.mysql import ContentChunks, Contents
from applications.config import Chunk, DEFAULT_MODEL, ELASTIC_SEARCH_INDEX


class ChunkEmbeddingTask(TopicAwarePackerV2):
    def __init__(self, doc_id, resource):
        super().__init__(doc_id)
        self.chunk_manager = None
        self.content_manager = None
        self.mysql_client = resource.mysql_client
        self.milvus_client = resource.milvus_client
        self.es_client = resource.es_client
        self.classifier = LLMClassifier()

    @staticmethod
    async def get_embedding_list(text: str) -> List:
        return await get_basic_embedding(text=text, model=DEFAULT_MODEL)

    def init_processer(self):
        self.content_manager = Contents(self.mysql_client)
        self.chunk_manager = ContentChunks(self.mysql_client)
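
    # All update_*_status calls below follow a compare-and-swap pattern:
    # the manager is assumed to flip a row from `ori_status` to
    # `new_status` atomically and to return a falsy value when the row
    # was not in `ori_status`, so a failed update doubles as a lost lock.
    # Lifecycle: INIT_STATUS -> PROCESSING_STATUS -> FINISHED_STATUS,
    # with FAILED_STATUS as the terminal error state.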

    async def _chunk_each_content(
        self,
        doc_id: str,
        text: str,
        text_type: int,
        title: str,
        dataset_id: int,
        re_chunk: bool,
    ) -> List[Chunk]:
        if re_chunk:
            # re-chunking an existing document: refresh its content row
            await self.content_manager.update_content_info(
                doc_id=doc_id,
                text=text,
                text_type=text_type,
                title=title,
                dataset_id=dataset_id,
            )
            flag = True
        else:
            flag = await self.content_manager.insert_content(
                doc_id, text, text_type, title, dataset_id
            )
        if not flag:
            return []

        raw_chunks = await self.chunk(text, text_type, dataset_id)
        if not raw_chunks:
            await self.content_manager.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []

        await self.content_manager.update_content_status(
            doc_id=doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        return raw_chunks

    async def insert_into_es(self, milvus_id, chunk: Chunk) -> int:
        """Index a single chunk document into Elasticsearch, keyed by its Milvus id."""
        docs = [
            {
                "_index": ELASTIC_SEARCH_INDEX,
                "_id": milvus_id,
                "_source": {
                    "milvus_id": milvus_id,
                    "doc_id": chunk.doc_id,
                    "dataset_id": chunk.dataset_id,
                    "chunk_id": chunk.chunk_id,
                    "topic": chunk.topic,
                    "domain": chunk.domain,
                    "task_type": chunk.task_type,
                    "text_type": chunk.text_type,
                    "keywords": chunk.keywords,
                    "concepts": chunk.concepts,
                    "entities": chunk.entities,
                    "status": chunk.status,
                },
            }
        ]
        resp = await self.es_client.bulk_insert(docs)
        return resp["success"]
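
    # Per-chunk pipeline: insert the row, take the processing lock,
    # classify with the LLM, persist the result, embed into Milvus, then
    # index into Elasticsearch.  Any stage that fails marks the chunk
    # FAILED_STATUS and stops; there is no retry logic here.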

    async def save_each_chunk(self, chunk: Chunk):
        # insert the chunk row
        flag = await self.chunk_manager.insert_chunk(chunk)
        if not flag:
            print("Failed to insert chunk")
            return

        acquire_lock = await self.chunk_manager.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print("Failed to acquire the chunk processing lock")
            return

        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("Failed to get a classification from DeepSeek")
            return

        update_flag = await self.chunk_manager.set_chunk_result(
            chunk=completion,
            ori_status=self.PROCESSING_STATUS,
            new_status=self.FINISHED_STATUS,
        )
        if not update_flag:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return

        milvus_id = await self.save_to_milvus(completion)
        if not milvus_id:
            return

        # index into Elasticsearch: acquire the ES status lock first
        acquire_es_lock = await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_es_lock:
            print(f"Failed to acquire the ES lock: {chunk.doc_id}--{chunk.chunk_id}")
            return

        insert_rows = await self.insert_into_es(milvus_id, completion)
        final_status = self.FINISHED_STATUS if insert_rows else self.FAILED_STATUS
        await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.PROCESSING_STATUS,
            new_status=final_status,
        )

    async def save_to_milvus(self, chunk: Chunk):
        """Embed a chunk and insert its vectors into Milvus.

        :param chunk: each single chunk
        :return: the Milvus primary key on success, otherwise None
        """
        # acquire the embedding status lock
        acquire_lock = await self.chunk_manager.update_embedding_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            new_status=self.PROCESSING_STATUS,
            ori_status=self.INIT_STATUS,
        )
        if not acquire_lock:
            print(
                f"Failed to acquire the embedding lock for chunk "
                f"{chunk.doc_id}-{chunk.chunk_id}"
            )
            return None
        try:
            data = {
                "doc_id": chunk.doc_id,
                "chunk_id": chunk.chunk_id,
                "vector_text": await self.get_embedding_list(chunk.text),
                "vector_summary": await self.get_embedding_list(chunk.summary),
                "vector_questions": await self.get_embedding_list(
                    ",".join(chunk.questions)
                ),
            }
            resp = await async_insert_chunk(self.milvus_client, data)
            if not resp:
                await self.chunk_manager.update_embedding_status(
                    doc_id=chunk.doc_id,
                    chunk_id=chunk.chunk_id,
                    ori_status=self.PROCESSING_STATUS,
                    new_status=self.FAILED_STATUS,
                )
                return None
            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
            return resp[0]
        except Exception as e:
            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("Failed to insert into the vector database:", e)
            return None
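
    # `deal` is fire-and-forget: it schedules `_process` on the running
    # event loop and returns the doc_id immediately.  The task reference
    # is not stored, and asyncio.create_task only keeps a weak reference
    # to its result, so callers must keep the loop alive (and ideally
    # hold a reference) until the background work completes.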

    async def deal(self, data):
        text = data.get("text", "")
        title = data.get("title", "")
        text, title = text.strip(), title.strip()
        text_type = data.get("text_type", 1)
        dataset_id = data.get("dataset_id", 0)  # dataset 0 is the default knowledge base
        re_chunk = data.get("re_chunk", False)
        if not text:
            return None

        self.init_processer()

        async def _process():
            chunks = await self._chunk_each_content(
                self.doc_id, text, text_type, title, dataset_id, re_chunk
            )
            if not chunks:
                return
            # # dev: process chunks serially
            # for chunk in chunks:
            #     await self.save_each_chunk(chunk)
            await run_tasks_with_asyncio_task_group(
                task_list=chunks,
                handler=self.save_each_chunk,
                description="process chunks of a single document",
                unit="chunk",
                max_concurrency=10,
            )
            await self.content_manager.update_content_status(
                doc_id=self.doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )

        asyncio.create_task(_process())
        return self.doc_id
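

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the pipeline).  The resource
# object just needs the three client attributes read in __init__, and the
# payload keys match what `deal` reads above.  `build_resource` is a
# hypothetical factory, not an existing helper.
#
#   async def main():
#       resource = build_resource()  # must expose mysql_client,
#                                    # milvus_client, es_client
#       task = ChunkEmbeddingTask(doc_id="doc-001", resource=resource)
#       doc_id = await task.deal(
#           {
#               "text": "some document body",   # empty text returns None
#               "title": "some title",
#               "text_type": 1,
#               "dataset_id": 0,     # default knowledge base
#               "re_chunk": False,   # True re-processes an existing doc
#           }
#       )
#       # deal() returns immediately; keep the event loop alive until the
#       # background task has finished.
#
#   asyncio.run(main())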