# chunk_task.py

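"""
Chunk-and-embed pipeline for a single document.

``ChunkEmbeddingTask`` splits a document into topic-aware chunks,
classifies each chunk with an LLM, writes the chunk embeddings to
Milvus, and indexes the chunk metadata in Elasticsearch. Status columns
in MySQL serve as optimistic locks so concurrent workers never process
the same chunk twice.
"""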
import asyncio
import json
from typing import List

from applications.api import get_basic_embedding
from applications.utils.async_utils import run_tasks_with_asyncio_task_group
from applications.utils.chunks import LLMClassifier, TopicAwarePackerV2
from applications.utils.milvus import async_insert_chunk
from applications.utils.mysql import ContentChunks, Contents
from applications.utils.nlp import num_tokens
from applications.config import Chunk, DEFAULT_MODEL, ELASTIC_SEARCH_INDEX


class ChunkEmbeddingTask(TopicAwarePackerV2):
    def __init__(self, doc_id, resource):
        super().__init__(doc_id)
        self.chunk_manager = None
        self.content_manager = None
        self.mysql_client = resource.mysql_client
        self.milvus_client = resource.milvus_client
        self.es_client = resource.es_client
        self.classifier = LLMClassifier()

    @staticmethod
    async def get_embedding_list(text: str) -> List:
        return await get_basic_embedding(text=text, model=DEFAULT_MODEL)

    def init_processer(self):
        self.content_manager = Contents(self.mysql_client)
        self.chunk_manager = ContentChunks(self.mysql_client)

    async def _chunk_each_content(self, doc_id: str, data: dict) -> List[Chunk]:
        title = (data.get("title") or "").strip()
        text = data["text"].strip()
        text_type = data.get("text_type", 1)
        dataset_id = data.get("dataset_id", 0)  # default knowledge-base id is 0
        re_chunk = data.get("re_chunk", False)
        dont_chunk = data.get("dont_chunk", False)
        ext = data.get("ext", None)
        if not title:
            # fall back to the "query" field of the ext payload
            if ext and isinstance(ext, str):
                try:
                    ext_dict = json.loads(ext)
                    title = ext_dict.get("query", None)
                except json.JSONDecodeError:
                    title = None
            else:
                title = None
        if re_chunk:
            await self.content_manager.update_content_info(
                doc_id=doc_id,
                text=text,
                text_type=text_type,
                title=title,
                dataset_id=dataset_id,
            )
            flag = True
        else:
            flag = await self.content_manager.insert_content(
                doc_id, text, text_type, title, dataset_id, ext
            )
        if not flag:
            return []
        raw_chunks = await self.chunk(text, text_type, dataset_id, dont_chunk)
        if not raw_chunks:
            await self.content_manager.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []
        await self.content_manager.update_content_status(
            doc_id=doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        return raw_chunks

    async def insert_into_es(self, milvus_id, chunk: Chunk) -> int:
        docs = [
            {
                "_index": ELASTIC_SEARCH_INDEX,
                "_id": milvus_id,
                "_source": {
                    "milvus_id": milvus_id,
                    "doc_id": chunk.doc_id,
                    "dataset_id": chunk.dataset_id,
                    "chunk_id": chunk.chunk_id,
                    "topic": chunk.topic,
                    "domain": chunk.domain,
                    "task_type": chunk.task_type,
                    "text_type": chunk.text_type,
                    "keywords": chunk.keywords,
                    "concepts": chunk.concepts,
                    "entities": chunk.entities,
                    "status": chunk.status,
                },
            }
        ]
        resp = await self.es_client.bulk_insert(docs)
        return resp["success"]
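
    # Each chunk advances through three independent status columns in
    # MySQL (chunk, embedding, es). A compare-and-set from INIT_STATUS
    # to PROCESSING_STATUS is the lock acquisition; the owner then moves
    # PROCESSING_STATUS to FINISHED_STATUS or FAILED_STATUS. When the
    # compare-and-set fails, another worker already owns the chunk and
    # this one backs off.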

    async def save_each_chunk(self, chunk: Chunk):
        # persist the raw chunk row
        flag = await self.chunk_manager.insert_chunk(chunk)
        if not flag:
            print("Failed to insert chunk")
            return
        acquire_lock = await self.chunk_manager.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print("Failed to acquire the chunk-processing lock")
            return
        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("Failed to get classification result from DeepSeek")
            return
        update_flag = await self.chunk_manager.set_chunk_result(
            chunk=completion,
            ori_status=self.PROCESSING_STATUS,
            new_status=self.FINISHED_STATUS,
        )
        if not update_flag:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return
        milvus_id = await self.save_to_milvus(completion)
        if not milvus_id:
            return
        # index the chunk metadata in Elasticsearch
        acquire_es_lock = await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_es_lock:
            print(f"Failed to acquire ES lock: {chunk.doc_id}--{chunk.chunk_id}")
            return
        insert_rows = await self.insert_into_es(milvus_id, completion)
        final_status = self.FINISHED_STATUS if insert_rows else self.FAILED_STATUS
        await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.PROCESSING_STATUS,
            new_status=final_status,
        )

    async def save_to_milvus(self, chunk: Chunk):
        """
        Embed a single chunk and insert the vectors into Milvus.

        :param chunk: each single chunk
        :return: the Milvus primary key on success, otherwise None
        """
        # acquire the embedding lock (INIT -> PROCESSING)
        acquire_lock = await self.chunk_manager.update_embedding_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            new_status=self.PROCESSING_STATUS,
            ori_status=self.INIT_STATUS,
        )
        if not acquire_lock:
            print(
                f"Failed to acquire the embedding lock for chunk "
                f"{chunk.doc_id}-{chunk.chunk_id}"
            )
            return None
        try:
            data = {
                "doc_id": chunk.doc_id,
                "chunk_id": chunk.chunk_id,
                "vector_text": await self.get_embedding_list(chunk.text),
                "vector_summary": await self.get_embedding_list(chunk.summary),
                "vector_questions": await self.get_embedding_list(
                    ",".join(chunk.questions)
                ),
            }
            resp = await async_insert_chunk(self.milvus_client, data)
            if not resp:
                await self.chunk_manager.update_embedding_status(
                    doc_id=chunk.doc_id,
                    chunk_id=chunk.chunk_id,
                    ori_status=self.PROCESSING_STATUS,
                    new_status=self.FAILED_STATUS,
                )
                return None
            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
            milvus_id = resp[0]
            return milvus_id
        except Exception as e:
            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("Failed to insert into the vector database:", e)
            return None

    async def deal(self, data):
        text = data.get("text", "")
        dont_chunk = data.get("dont_chunk", False)
        dataset_id = data.get("dataset_id", 0)
        if dataset_id == 12:
            data["dont_chunk"] = True
            dont_chunk = True  # keep the local flag in sync with the payload
        # when chunking is skipped, the whole text must still fit the token budget
        if dont_chunk and num_tokens(text) >= self.max_tokens:
            data["dont_chunk"] = False
            # return {"error": "document exceeds the maximum length supported by the model"}
        self.init_processer()

        async def _process():
            chunks = await self._chunk_each_content(self.doc_id, data)
            if not chunks:
                return
            # # dev
            # for chunk in tqdm(chunks):
            #     await self.save_each_chunk(chunk)
            await run_tasks_with_asyncio_task_group(
                task_list=chunks,
                handler=self.save_each_chunk,
                description="process chunks of a single document",
                unit="chunk",
                max_concurrency=10,
            )
            await self.content_manager.update_content_status(
                doc_id=self.doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )

        # fire-and-forget: schedule the pipeline and return the doc_id immediately
        asyncio.create_task(_process())
        return self.doc_id
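

# --- Usage sketch (illustrative, not part of the module) ---
# A minimal driver, assuming a resource holder that exposes the three
# clients read in ``__init__`` (mysql_client, milvus_client, es_client).
# ``SimpleNamespace`` and the ``my_*`` clients below are hypothetical
# placeholders supplied by the caller.
#
#   import asyncio
#   from types import SimpleNamespace
#
#   async def main():
#       resource = SimpleNamespace(
#           mysql_client=my_mysql_client,    # hypothetical client objects
#           milvus_client=my_milvus_client,
#           es_client=my_es_client,
#       )
#       task = ChunkEmbeddingTask(doc_id="doc-123", resource=resource)
#       doc_id = await task.deal(
#           {
#               "title": "Example title",
#               "text": "Example body text ...",
#               "text_type": 1,
#               "dataset_id": 0,
#           }
#       )
#       print("scheduled:", doc_id)  # chunking continues in the background
#
#   asyncio.run(main())
#
# Note: ``deal`` only schedules ``_process`` on the running loop, so a
# real driver must keep the event loop alive until processing finishes.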