import asyncio
from typing import List

from applications.api import get_basic_embedding
from applications.config import Chunk, DEFAULT_MODEL, ELASTIC_SEARCH_INDEX
from applications.utils.async_utils import run_tasks_with_asyncio_task_group
from applications.utils.chunks import LLMClassifier, TopicAwarePackerV2
from applications.utils.milvus import async_insert_chunk
from applications.utils.mysql import ContentChunks, Contents


class ChunkEmbeddingTask(TopicAwarePackerV2):
    """Chunk a document, classify each chunk, and persist the results to
    MySQL, Milvus (vectors), and Elasticsearch (metadata)."""

    def __init__(self, doc_id, resource):
        super().__init__(doc_id)
        self.chunk_manager = None
        self.content_manager = None
        self.mysql_client = resource.mysql_client
        self.milvus_client = resource.milvus_client
        self.es_client = resource.es_client
        self.classifier = LLMClassifier()

    @staticmethod
    async def get_embedding_list(text: str) -> List[float]:
        return await get_basic_embedding(text=text, model=DEFAULT_MODEL)

    def init_processer(self):
        # Wire up the MySQL-backed managers; called once per deal().
        self.content_manager = Contents(self.mysql_client)
        self.chunk_manager = ContentChunks(self.mysql_client)

    async def _chunk_each_content(
        self,
        doc_id: str,
        text: str,
        text_type: int,
        title: str,
        dataset_id: int,
        re_chunk: bool,
    ) -> List[Chunk]:
        # Upsert the content row; re_chunk means the document already exists.
        if re_chunk:
            flag = await self.content_manager.update_content_info(
                doc_id=doc_id,
                text=text,
                text_type=text_type,
                title=title,
                dataset_id=dataset_id,
            )
        else:
            flag = await self.content_manager.insert_content(
                doc_id, text, text_type, title, dataset_id
            )
        if not flag:
            return []

        raw_chunks = await self.chunk(text, text_type, dataset_id)
        if not raw_chunks:
            await self.content_manager.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []

        await self.content_manager.update_content_status(
            doc_id=doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        return raw_chunks
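
    # Status lifecycle used throughout this class (constants inherited from
    # TopicAwarePackerV2): INIT -> PROCESSING -> FINISHED | FAILED. Each
    # update_*_status(ori_status=..., new_status=...) call doubles as an
    # optimistic lock: it succeeds only while the row is still in
    # `ori_status`, so concurrent workers cannot grab the same chunk twice.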

    async def insert_into_es(self, milvus_id, chunk: Chunk) -> int:
        docs = [
            {
                "_index": ELASTIC_SEARCH_INDEX,
                "_id": milvus_id,
                "_source": {
                    "milvus_id": milvus_id,
                    "doc_id": chunk.doc_id,
                    "dataset_id": chunk.dataset_id,
                    "chunk_id": chunk.chunk_id,
                    "topic": chunk.topic,
                    "domain": chunk.domain,
                    "task_type": chunk.task_type,
                    "text_type": chunk.text_type,
                    "keywords": chunk.keywords,
                    "concepts": chunk.concepts,
                    "entities": chunk.entities,
                    "status": chunk.status,
                },
            }
        ]
        resp = await self.es_client.bulk_insert(docs)
        return resp["success"]
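
    # Note: es_client.bulk_insert is assumed to return a summary dict whose
    # "success" field counts the documents actually indexed; insert_into_es
    # surfaces that count so callers can treat 0 as failure.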

    async def save_each_chunk(self, chunk: Chunk):
        # Persist the raw chunk row first.
        flag = await self.chunk_manager.insert_chunk(chunk)
        if not flag:
            print("Failed to insert chunk into MySQL")
            return

        acquire_lock = await self.chunk_manager.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print("Failed to acquire the chunk processing lock")
            return

        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("Failed to get classification result from DeepSeek")
            return

        update_flag = await self.chunk_manager.set_chunk_result(
            chunk=completion,
            ori_status=self.PROCESSING_STATUS,
            new_status=self.FINISHED_STATUS,
        )
        if not update_flag:
            await self.chunk_manager.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return

        milvus_id = await self.save_to_milvus(completion)
        if not milvus_id:
            return

        # Persist chunk metadata to Elasticsearch, guarded by its own lock.
        acquire_es_lock = await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_es_lock:
            print(f"Failed to acquire ES lock: {chunk.doc_id}--{chunk.chunk_id}")
            return

        insert_rows = await self.insert_into_es(milvus_id, completion)
        final_status = self.FINISHED_STATUS if insert_rows else self.FAILED_STATUS
        await self.chunk_manager.update_es_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.PROCESSING_STATUS,
            new_status=final_status,
        )

    async def save_to_milvus(self, chunk: Chunk):
        """Embed the chunk and write the vectors to Milvus.

        :param chunk: a single processed chunk
        :return: the Milvus primary key on success, otherwise None
        """
        # Acquire the embedding lock via an optimistic status swap.
        acquire_lock = await self.chunk_manager.update_embedding_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print(
                f"Failed to acquire the embedding lock for chunk "
                f"{chunk.doc_id}-{chunk.chunk_id}"
            )
            return None
        try:
            data = {
                "doc_id": chunk.doc_id,
                "chunk_id": chunk.chunk_id,
                "vector_text": await self.get_embedding_list(chunk.text),
                "vector_summary": await self.get_embedding_list(chunk.summary),
                "vector_questions": await self.get_embedding_list(
                    ",".join(chunk.questions)
                ),
            }
            resp = await async_insert_chunk(self.milvus_client, data)
            if not resp:
                await self.chunk_manager.update_embedding_status(
                    doc_id=chunk.doc_id,
                    chunk_id=chunk.chunk_id,
                    ori_status=self.PROCESSING_STATUS,
                    new_status=self.FAILED_STATUS,
                )
                return None

            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
            return resp[0]
        except Exception as e:
            await self.chunk_manager.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print(f"Failed to insert into the vector store: {e}")
            return None
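
    # A minimal sketch, not used by the pipeline above: save_to_milvus awaits
    # its three embedding calls sequentially; if get_basic_embedding tolerates
    # concurrent calls, they could be fetched in parallel with asyncio.gather.
    # `_embed_fields_concurrently` is a hypothetical helper, not part of the
    # original flow.
    @staticmethod
    async def _embed_fields_concurrently(chunk: Chunk) -> dict:
        vector_text, vector_summary, vector_questions = await asyncio.gather(
            get_basic_embedding(text=chunk.text, model=DEFAULT_MODEL),
            get_basic_embedding(text=chunk.summary, model=DEFAULT_MODEL),
            get_basic_embedding(text=",".join(chunk.questions), model=DEFAULT_MODEL),
        )
        return {
            "doc_id": chunk.doc_id,
            "chunk_id": chunk.chunk_id,
            "vector_text": vector_text,
            "vector_summary": vector_summary,
            "vector_questions": vector_questions,
        }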

    async def deal(self, data):
        text = data.get("text", "")
        title = data.get("title", "")
        text, title = text.strip(), title.strip()
        text_type = data.get("text_type", 1)
        dataset_id = data.get("dataset_id", 0)  # dataset 0 is the default knowledge base
        re_chunk = data.get("re_chunk", False)
        if not text:
            return None

        self.init_processer()

        async def _process():
            chunks = await self._chunk_each_content(
                self.doc_id, text, text_type, title, dataset_id, re_chunk
            )
            if not chunks:
                return
            # Serial fallback, useful when debugging:
            # for chunk in chunks:
            #     await self.save_each_chunk(chunk)
            await run_tasks_with_asyncio_task_group(
                task_list=chunks,
                handler=self.save_each_chunk,
                description="process the chunks of a single document",
                unit="chunk",
                max_concurrency=10,
            )
            await self.content_manager.update_content_status(
                doc_id=self.doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )

        # Keep a reference to the background task so it is not garbage-collected
        # mid-flight; deal() returns the doc_id immediately while chunking and
        # persistence continue in the background.
        self._process_task = asyncio.create_task(_process())
        return self.doc_id
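

# Minimal usage sketch, assuming `resource` is the application's shared client
# container (exposing mysql_client, milvus_client, and es_client); the doc_id
# and payload values below are placeholders. Because deal() finishes its work
# in a background task, the caller must keep the event loop alive until that
# task completes.
async def _demo(resource) -> str:
    task = ChunkEmbeddingTask(doc_id="doc-demo-001", resource=resource)
    return await task.deal(
        {"text": "some document body", "title": "a title", "dataset_id": 0}
    )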