import asyncio
import uuid
from typing import List

from applications.api import get_basic_embedding
from applications.utils.async_utils import run_tasks_with_asyncio_task_group
from applications.utils.mysql import ContentChunks, Contents
from applications.utils.chunks import TopicAwareChunker, LLMClassifier
from applications.utils.milvus import async_insert_chunk
from applications.config import Chunk, ChunkerConfig, DEFAULT_MODEL
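
# Pipeline overview:
# 1. Persist the raw document and split it into topic-aware chunks.
# 2. Classify each chunk with an LLM (topic, domain, task type, summary, ...).
# 3. Embed the chunk text, summary, and generated questions.
# 4. Write vectors plus metadata to Milvus, tracking every step in MySQL via
#    INIT -> PROCESSING -> FINISHED/FAILED status transitions.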

class ChunkEmbeddingTask(TopicAwareChunker):
    def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig, doc_id):
        super().__init__(cfg, doc_id)
        self.content_chunk_processor = None
        self.contents_processor = None
        self.mysql_pool = mysql_pool
        self.vector_pool = vector_pool
        self.classifier = LLMClassifier()

    @staticmethod
    async def get_embedding_list(text: str) -> List:
        return await get_basic_embedding(text=text, model=DEFAULT_MODEL)

    def init_processor(self):
        self.contents_processor = Contents(self.mysql_pool)
        self.content_chunk_processor = ContentChunks(self.mysql_pool)

    async def process_content(self, doc_id, text) -> List[Chunk]:
        # Register the content row first; bail out if the insert fails
        # (e.g. the document is already registered).
        flag = await self.contents_processor.insert_content(doc_id, text)
        if not flag:
            return []

        raw_chunks = await self.chunk(text)
        if not raw_chunks:
            await self.contents_processor.update_content_status(
                doc_id=doc_id,
                ori_status=self.INIT_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return []

        await self.contents_processor.update_content_status(
            doc_id=doc_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        return raw_chunks
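
    # Note: the conditional status updates below double as an optimistic lock.
    # Only the caller that wins the INIT -> PROCESSING transition proceeds, so
    # concurrent workers never handle the same chunk twice.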
    async def process_each_chunk(self, chunk: Chunk):
        # Insert the chunk row; skip this chunk if the insert fails.
        flag = await self.content_chunk_processor.insert_chunk(chunk)
        if not flag:
            print("failed to insert chunk")
            return

        # Acquire the per-chunk processing lock.
        acquire_lock = await self.content_chunk_processor.update_chunk_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            ori_status=self.INIT_STATUS,
            new_status=self.PROCESSING_STATUS,
        )
        if not acquire_lock:
            print("failed to acquire chunk processing lock")
            return

        # Classify the chunk with the LLM (topic, domain, task type, ...).
        completion = await self.classifier.classify_chunk(chunk)
        if not completion:
            await self.content_chunk_processor.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("failed to get classification result from DeepSeek")
            return

        update_flag = await self.content_chunk_processor.set_chunk_result(
            chunk=completion,
            ori_status=self.PROCESSING_STATUS,
            new_status=self.FINISHED_STATUS,
        )
        if not update_flag:
            await self.content_chunk_processor.update_chunk_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            return

        await self.save_to_milvus(completion)

    async def save_to_milvus(self, chunk: Chunk):
        """
        Embed a single chunk and write it to Milvus.

        :param chunk: a single classified chunk
        """
        # Acquire the per-chunk embedding lock (INIT -> PROCESSING).
        acquire_lock = await self.content_chunk_processor.update_embedding_status(
            doc_id=chunk.doc_id,
            chunk_id=chunk.chunk_id,
            new_status=self.PROCESSING_STATUS,
            ori_status=self.INIT_STATUS,
        )
        if not acquire_lock:
            print(
                f"failed to acquire embedding lock for chunk {chunk.doc_id}-{chunk.chunk_id}"
            )
            return
        try:
            data = {
                "doc_id": chunk.doc_id,
                "chunk_id": chunk.chunk_id,
                "vector_text": await self.get_embedding_list(chunk.text),
                "vector_summary": await self.get_embedding_list(chunk.summary),
                "vector_questions": await self.get_embedding_list(
                    ",".join(chunk.questions)
                ),
                "topic": chunk.topic,
                "domain": chunk.domain,
                "task_type": chunk.task_type,
                "summary": chunk.summary,
                "keywords": chunk.keywords,
                "entities": chunk.entities,
                "concepts": chunk.concepts,
                "questions": chunk.questions,
                "topic_purity": chunk.topic_purity,
                "tokens": chunk.tokens,
            }
            await async_insert_chunk(self.vector_pool, data)
            await self.content_chunk_processor.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )
        except Exception as e:
            await self.content_chunk_processor.update_embedding_status(
                doc_id=chunk.doc_id,
                chunk_id=chunk.chunk_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FAILED_STATUS,
            )
            print("failed to insert into vector database:", e)

    async def deal(self, data):
        text = data.get("text", "").strip()
        if not text:
            return None

        self.init_processor()

        async def _process():
            chunks = await self.process_content(self.doc_id, text)
            if not chunks:
                return
            # dev: process chunks sequentially instead
            # for chunk in chunks:
            #     await self.process_each_chunk(chunk)
            await run_tasks_with_asyncio_task_group(
                task_list=chunks,
                handler=self.process_each_chunk,
                description="process chunks of a single article",
                unit="chunk",
                max_concurrency=10,
            )
            await self.contents_processor.update_content_status(
                doc_id=self.doc_id,
                ori_status=self.PROCESSING_STATUS,
                new_status=self.FINISHED_STATUS,
            )

        # Fire-and-forget, but keep a reference so the task cannot be
        # garbage-collected before it finishes (see asyncio.create_task docs).
        self._process_task = asyncio.create_task(_process())
        return self.doc_id
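
# --- Hypothetical usage sketch (not part of the original module) ---
# `make_mysql_pool` and `make_vector_pool` are illustrative placeholders for
# however the surrounding project builds its connection pools; substitute the
# real factories. ChunkerConfig() is assumed to have usable defaults.
#
# async def main():
#     mysql_pool = await make_mysql_pool()    # hypothetical factory
#     vector_pool = await make_vector_pool()  # hypothetical factory
#     task = ChunkEmbeddingTask(
#         mysql_pool, vector_pool, cfg=ChunkerConfig(), doc_id=str(uuid.uuid4())
#     )
#     doc_id = await task.deal({"text": "raw document text ..."})
#     print("scheduled background processing for doc:", doc_id)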