chunk_task.py

import asyncio
import json
from typing import List

from applications.api import get_basic_embedding
from applications.config import Chunk, DEFAULT_MODEL, ELASTIC_SEARCH_INDEX
from applications.utils.async_utils import run_tasks_with_asyncio_task_group
from applications.utils.chunks import LLMClassifier, TopicAwarePackerV2
from applications.utils.milvus import async_insert_chunk
from applications.utils.mysql import Books, ContentChunks, Contents
from applications.utils.nlp import num_tokens


class ChunkEmbeddingTask(TopicAwarePackerV2):
    def __init__(self, doc_id, resource):
        super().__init__(doc_id)
        self.chunk_manager = None
        self.content_manager = None
        self.book_manager = None
        self.mysql_client = resource.mysql_client
        self.milvus_client = resource.milvus_client
        self.es_client = resource.es_client
        self.classifier = LLMClassifier()

    @staticmethod
    async def get_embedding_list(text: str) -> List:
        return await get_basic_embedding(text=text, model=DEFAULT_MODEL)

    def init_processer(self):
        self.content_manager = Contents(self.mysql_client)
        self.chunk_manager = ContentChunks(self.mysql_client)
        self.book_manager = Books(self.mysql_client)

    async def _chunk_each_content(self, doc_id: str, data: dict) -> List[Chunk]:
        title, text = data.get("title", "").strip(), data["text"].strip()
        text_type = data.get("text_type", 1)
        dataset_id = data.get("dataset_id", 0)  # default dataset id is 0
        re_chunk = data.get("re_chunk", False)
        dont_chunk = data.get("dont_chunk", False)
        ext = data.get("ext", None)
        # `title` is always a (possibly empty) string at this point, so test
        # for emptiness rather than `is None`; fall back to the query in ext
        if not title:
            if ext and isinstance(ext, str):
                try:
                    ext_dict = json.loads(ext)
                    title = ext_dict.get("query", None)
                except json.JSONDecodeError:
                    title = None
            else:
                title = None
        if re_chunk:
            await self.content_manager.update_content_info(
                doc_id=doc_id,
                text=text,
                text_type=text_type,
                title=title,
                dataset_id=dataset_id,
            )
            flag = True
        else:
            flag = await self.content_manager.insert_content(
                doc_id, text, text_type, title, dataset_id, ext
            )
        if not flag:
            return []

        raw_chunks = await self.chunk(text, text_type, dataset_id, dont_chunk)
        if not raw_chunks:
            await self.content_manager.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []
        await self.content_manager.update_content_status(
            doc_id=doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        return raw_chunks

    async def insert_into_es(self, milvus_id, chunk: Chunk) -> int:
        docs = [
            {
                "_index": ELASTIC_SEARCH_INDEX,
                "_id": milvus_id,
                "_source": {
                    "milvus_id": milvus_id,
                    "doc_id": chunk.doc_id,
                    "dataset_id": chunk.dataset_id,
                    "chunk_id": chunk.chunk_id,
                    "topic": chunk.topic,
                    "domain": chunk.domain,
                    "task_type": chunk.task_type,
                    "text_type": chunk.text_type,
                    "keywords": chunk.keywords,
                    "concepts": chunk.concepts,
                    "entities": chunk.entities,
                    "status": chunk.status,
                },
            }
        ]
        resp = await self.es_client.bulk_insert(docs)
        return resp["success"]

    async def save_each_chunk(self, chunk: Chunk):
        # insert the raw chunk row
        flag = await self.chunk_manager.insert_chunk(chunk)
        if not flag:
            print("failed to insert chunk")
            return
        acquire_lock = await self.chunk_manager.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print("failed to acquire chunk-processing lock")
            return
        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("failed to get classification result from deepseek")
            return
        update_flag = await self.chunk_manager.set_chunk_result(
            chunk=completion,
            ori_status=self.PROCESSING_STATUS,
            new_status=self.FINISHED_STATUS,
        )
        if not update_flag:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return
        milvus_id = await self.save_to_milvus(completion)
        if not milvus_id:
            return
        # store the chunk into es; acquire the es lock first
        acquire_es_lock = await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_es_lock:
            print(f"failed to acquire es lock: {chunk.doc_id}--{chunk.chunk_id}")
            return
        insert_rows = await self.insert_into_es(milvus_id, completion)
        final_status = self.FINISHED_STATUS if insert_rows else self.FAILED_STATUS
        await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.PROCESSING_STATUS,
            new_status=final_status,
        )

    async def save_to_milvus(self, chunk: Chunk):
        """
        :param chunk: each single chunk
        :return: the milvus primary key on success, otherwise None
        """
        # acquire the embedding lock via an optimistic status transition
        acquire_lock = await self.chunk_manager.update_embedding_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            new_status=self.PROCESSING_STATUS,
            ori_status=self.INIT_STATUS,
        )
        if not acquire_lock:
            print(
                f"failed to acquire embedding lock for chunk {chunk.doc_id}-{chunk.chunk_id}"
            )
            return None
        try:
            data = {
                "doc_id": chunk.doc_id,
                "chunk_id": chunk.chunk_id,
                "vector_text": await self.get_embedding_list(chunk.text),
                "vector_summary": await self.get_embedding_list(chunk.summary),
                "vector_questions": await self.get_embedding_list(
                    ",".join(chunk.questions)
                ),
            }
            resp = await async_insert_chunk(self.milvus_client, data)
            if not resp:
                await self.chunk_manager.update_embedding_status(
                    doc_id=chunk.doc_id,
                    chunk_id=chunk.chunk_id,
                    ori_status=self.PROCESSING_STATUS,
                    new_status=self.FAILED_STATUS,
                )
                return None
            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
            milvus_id = resp[0]
            return milvus_id
        except Exception as e:
            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("failed to insert into the vector database", e)
            return None

    async def deal(self, data):
        text = data.get("text", "")
        dont_chunk = data.get("dont_chunk", False)
        dataset_id = data.get("dataset_id", 0)
        if dataset_id == 12:
            # keep the local flag in sync so the length guard below applies
            dont_chunk = data["dont_chunk"] = True
        # if chunking is disabled, fall back to chunking when the text
        # exceeds the model's token limit
        if dont_chunk and num_tokens(text) >= self.max_tokens:
            data["dont_chunk"] = False
            # return {"error": "document exceeds the model's maximum supported length"}
        self.init_processer()

        async def _process():
            chunks = await self._chunk_each_content(self.doc_id, data)
            if not chunks:
                return
            # # dev
            # for chunk in tqdm(chunks):
            #     await self.save_each_chunk(chunk)
            await run_tasks_with_asyncio_task_group(
                task_list=chunks,
                handler=self.save_each_chunk,
                description="process chunks of a single document",
                unit="chunk",
                max_concurrency=10,
            )
            await self.content_manager.update_content_status(
                doc_id=self.doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )

        asyncio.create_task(_process())
        return self.doc_id


class ChunkBooksTask(ChunkEmbeddingTask):
    """Chunking task for book-type documents."""

    BOOK_PDF_DATASET_ID = 17
    BOOK_PDF_TYPE = 3

    async def _process_each_book(self, book_id):
        result = await self.book_manager.get_book_extract_detail(book_id=book_id)
        extract_result = result[0]["extract_result"]
        book_name = result[0]["book_name"]
        book_oss_path = result[0]["book_oss_path"]
        book_texts = [
            i["text"] for i in json.loads(extract_result) if i["type"] == "text"
        ]
        # first insert into contents
        flag = await self.content_manager.insert_content(
            self.doc_id,
            book_oss_path,
            self.BOOK_PDF_TYPE,
            book_name,
            self.BOOK_PDF_DATASET_ID,
            ext=None,
        )
        if not flag:
            return []

        raw_chunks = await self.chunk_books(
            sentence_list=book_texts,
            text_type=self.BOOK_PDF_TYPE,
            dataset_id=self.BOOK_PDF_DATASET_ID,
        )
        if not raw_chunks:
            await self.content_manager.update_content_status(
                doc_id=self.doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []
        await self.content_manager.update_content_status(
            doc_id=self.doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        return raw_chunks

    async def deal(self, data):
        book_id = data.get("book_id", None)
        if not book_id:
            return {"error": "Book id should not be None"}
        self.init_processer()
        # acquire the book-level processing lock
        acquire_lock = await self.book_manager.update_book_chunk_status(
            book_id=book_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            return {"info": "book is processing or processed"}

        async def _process():
            chunks = await self._process_each_book(book_id)
            if not chunks:
                return
            # # dev
            # for chunk in tqdm(chunks):
            #     await self.save_each_chunk(chunk)
            await run_tasks_with_asyncio_task_group(
                task_list=chunks,
                handler=self.save_each_chunk,
                description="process chunks of a single book",
                unit="chunk",
                max_concurrency=10,
            )
            await self.content_manager.update_content_status(
                doc_id=self.doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
            await self.book_manager.update_book_chunk_status(
                book_id=book_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )

        asyncio.create_task(_process())
        return self.doc_id


__all__ = ["ChunkEmbeddingTask", "ChunkBooksTask"]
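
A minimal driver sketch for the module above. The `ResourceBundle` dataclass and the import path are assumptions, not part of this codebase; the real `resource` object only needs to expose `mysql_client`, `milvus_client`, and `es_client` attributes. Note that `deal` schedules `_process()` on the running event loop with `asyncio.create_task` and returns the `doc_id` immediately, so the caller must keep the loop alive until the background work completes.

# usage sketch -- illustrative only; ResourceBundle and the import path
# below are hypothetical, not part of chunk_task.py
import asyncio
from dataclasses import dataclass

from applications.tasks.chunk_task import ChunkEmbeddingTask  # hypothetical path


@dataclass
class ResourceBundle:
    # stand-in for the real resource container; any object exposing these
    # three attributes satisfies ChunkEmbeddingTask.__init__
    mysql_client: object
    milvus_client: object
    es_client: object


async def main():
    resource = ResourceBundle(
        mysql_client=...,  # wire up real clients here
        milvus_client=...,
        es_client=...,
    )
    task = ChunkEmbeddingTask(doc_id="doc-001", resource=resource)
    payload = {
        "title": "Example document",
        "text": "Full document text goes here ...",
        "text_type": 1,
        "dataset_id": 0,
    }
    doc_id = await task.deal(payload)  # returns right after scheduling
    print("scheduled chunking for", doc_id)
    # deal() runs _process() as a background task; a real service would stay
    # inside a long-lived event loop instead of sleeping like this sketch
    await asyncio.sleep(1)


if __name__ == "__main__":
    asyncio.run(main())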