luojunhui 3 weeks ago
parent
commit
667dea3b45

+ 2 - 2
applications/async_task/__init__.py

@@ -1,4 +1,4 @@
-from .chunk_task import ChunkTask
+from .chunk_task import ChunkEmbeddingTask
 
 
-__all__ = ['ChunkTask']
+__all__ = ["ChunkEmbeddingTask"]

+ 94 - 21
applications/async_task/chunk_task.py

@@ -2,12 +2,15 @@ import asyncio
 import uuid
 from typing import List
 
+from applications.api import get_basic_embedding
+from applications.utils.async_utils import run_tasks_with_asyncio_task_group
 from applications.utils.mysql import ContentChunks, Contents
 from applications.utils.chunks import TopicAwareChunker, LLMClassifier
-from applications.config import DEFAULT_MODEL, Chunk, ChunkerConfig
+from applications.utils.milvus import async_insert_chunk
+from applications.config import Chunk, ChunkerConfig, DEFAULT_MODEL
 
 
-class ChunkTask(TopicAwareChunker):
+class ChunkEmbeddingTask(TopicAwareChunker):
     def __init__(self, mysql_pool, vector_pool, cfg: ChunkerConfig):
         super().__init__(cfg)
         self.content_chunk_processor = None
@@ -16,6 +19,10 @@ class ChunkTask(TopicAwareChunker):
         self.vector_pool = vector_pool
         self.classifier = LLMClassifier()
 
+    @staticmethod
+    async def get_embedding_list(text: str) -> List:
+        return await get_basic_embedding(text=text, model=DEFAULT_MODEL, dev=True)
+
     def init_processer(self):
         self.contents_processor = Contents(self.mysql_pool)
         self.content_chunk_processor = ContentChunks(self.mysql_pool)
@@ -28,14 +35,17 @@ class ChunkTask(TopicAwareChunker):
             raw_chunks = await self.chunk(text)
             if not raw_chunks:
                 await self.contents_processor.update_content_status(
-                    doc_id=doc_id, ori_status=self.INIT_STATUS, new_status=self.FAILED_STATUS
+                    doc_id=doc_id,
+                    ori_status=self.INIT_STATUS,
+                    new_status=self.FAILED_STATUS,
                 )
                 return []
 
             affected_rows = await self.contents_processor.update_content_status(
-                doc_id=doc_id, ori_status=self.INIT_STATUS, new_status=self.PROCESSING_STATUS
+                doc_id=doc_id,
+                ori_status=self.INIT_STATUS,
+                new_status=self.PROCESSING_STATUS,
             )
-            print(affected_rows)
             return raw_chunks
 
     async def process_each_chunk(self, chunk: Chunk):
@@ -45,7 +55,10 @@ class ChunkTask(TopicAwareChunker):
             return
 
         acquire_lock = await self.content_chunk_processor.update_chunk_status(
-            doc_id=chunk.doc_id, chunk_id=chunk.chunk_id, ori_status=self.INIT_STATUS, new_status=self.PROCESSING_STATUS
+            doc_id=chunk.doc_id,
+            chunk_id=chunk.chunk_id,
+            ori_status=self.INIT_STATUS,
+            new_status=self.PROCESSING_STATUS,
         )
         if not acquire_lock:
             return
@@ -53,17 +66,78 @@ class ChunkTask(TopicAwareChunker):
         completion = await self.classifier.classify_chunk(chunk)
         if not completion:
             await self.content_chunk_processor.update_chunk_status(
-                doc_id=chunk.doc_id, chunk_id=chunk.chunk_id, ori_status=self.PROCESSING_STATUS, new_status=self.FAILED_STATUS
+                doc_id=chunk.doc_id,
+                chunk_id=chunk.chunk_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FAILED_STATUS,
             )
+            return
 
         update_flag = await self.content_chunk_processor.set_chunk_result(
-            chunk=completion, new_status=self.FINISHED_STATUS, ori_status=self.PROCESSING_STATUS
+            chunk=completion,
+            ori_status=self.PROCESSING_STATUS,
+            new_status=self.FINISHED_STATUS,
         )
         if not update_flag:
             await self.content_chunk_processor.update_chunk_status(
-                doc_id=chunk.doc_id, chunk_id=chunk.chunk_id, ori_status=self.PROCESSING_STATUS, new_status=self.FAILED_STATUS
+                doc_id=chunk.doc_id,
+                chunk_id=chunk.chunk_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FAILED_STATUS,
             )
+            return
 
+        await self.save_to_milvus(completion)
+
+    async def save_to_milvus(self, chunk: Chunk):
+        """
+        :param chunk: a single classified chunk to embed and insert
+        :return:
+        """
+        # acquire the embedding-processing lock
+        acquire_lock = await self.content_chunk_processor.update_embedding_status(
+            doc_id=chunk.doc_id,
+            chunk_id=chunk.chunk_id,
+            new_status=self.PROCESSING_STATUS,
+            ori_status=self.INIT_STATUS,
+        )
+        if not acquire_lock:
+            print(f"抢占-{chunk.doc_id}-{chunk.chunk_id}分块-embedding处理锁失败")
+            return
+        try:
+            data = {
+                "doc_id": chunk.doc_id,
+                "chunk_id": chunk.chunk_id,
+                "vector_text": await self.get_embedding_list(chunk.text),
+                "vector_summary": await self.get_embedding_list(chunk.summary),
+                "vector_questions": await self.get_embedding_list(
+                    ",".join(chunk.questions)
+                ),
+                "topic": chunk.topic,
+                "domain": chunk.domain,
+                "task_type": chunk.task_type,
+                "summary": chunk.summary,
+                "keywords": chunk.keywords,
+                "concepts": chunk.concepts,
+                "questions": chunk.questions,
+                "topic_purity": chunk.topic_purity,
+                "tokens": chunk.tokens,
+            }
+            await async_insert_chunk(self.vector_pool, data)
+            await self.content_chunk_processor.update_embedding_status(
+                doc_id=chunk.doc_id,
+                chunk_id=chunk.chunk_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FINISHED_STATUS,
+            )
+        except Exception as e:
+            await self.content_chunk_processor.update_embedding_status(
+                doc_id=chunk.doc_id,
+                chunk_id=chunk.chunk_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FAILED_STATUS,
+            )
+            print(f"存入向量数据库失败", e)
 
     async def deal(self, data):
         text = data.get("text")
@@ -78,20 +152,19 @@ class ChunkTask(TopicAwareChunker):
             if not chunks:
                 return
 
-            # process in batches
-            async with asyncio.TaskGroup() as tg:
-                for chunk in chunks:
-                    tg.create_task(self.process_each_chunk(chunk))
+            await run_tasks_with_asyncio_task_group(
+                task_list=chunks,
+                handler=self.process_each_chunk,
+                description="处理单篇文章分块",
+                unit="chunk",
+                max_concurrency=10,
+            )
 
             await self.contents_processor.update_content_status(
-                doc_id=doc_id, ori_status=self.PROCESSING_STATUS, new_status=self.FINISHED_STATUS
+                doc_id=doc_id,
+                ori_status=self.PROCESSING_STATUS,
+                new_status=self.FINISHED_STATUS,
             )
 
-        await _process()
-        # asyncio.create_task(_process())
+        asyncio.create_task(_process())
         return doc_id
-
-
-
-
-

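The `deal()` hunk above replaces `await _process()` with `asyncio.create_task(_process())`, so the caller now gets `doc_id` back immediately while chunking, classification, and embedding continue in the background. One caveat with fire-and-forget tasks: the event loop holds only weak references to them, so a task nothing else references can be garbage-collected before it finishes. A minimal sketch of the reference-keeping pattern from the asyncio docs (the `_background_tasks` set and `spawn()` helper are hypothetical names, not part of this commit):

```python
import asyncio

# Hypothetical helper (not in this commit): keep strong references to
# fire-and-forget tasks so they cannot be garbage-collected mid-flight.
_background_tasks: set[asyncio.Task] = set()

def spawn(coro) -> asyncio.Task:
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    # discard the reference once the task completes
    task.add_done_callback(_background_tasks.discard)
    return task
```

With such a helper, `deal()` could schedule `spawn(_process())` and keep the same non-blocking behavior without the garbage-collection risk. Separately, the three `get_embedding_list` awaits in `save_to_milvus` run sequentially; since they are independent, they could be combined with `asyncio.gather` if embedding latency becomes a bottleneck.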
+ 3 - 0
applications/utils/async_utils/__init__.py

@@ -0,0 +1,3 @@
+from .group_task import run_tasks_with_asyncio_task_group
+
+__all__ = ["run_tasks_with_asyncio_task_group"]

+ 51 - 0
applications/utils/async_utils/group_task.py

@@ -0,0 +1,51 @@
+import asyncio
+from typing import Callable, Coroutine, List, Any, Dict
+
+from tqdm.asyncio import tqdm
+
+
+# Use asyncio.TaskGroup to efficiently handle I/O-bound tasks
+async def run_tasks_with_asyncio_task_group(
+    task_list: List[Any],
+    handler: Callable[[Any], Coroutine[Any, Any, None]],
+    *,
+    description: str | None = None,  # task description shown on the progress bar
+    unit: str,
+    max_concurrency: int = 20,  # maximum number of concurrent tasks
+    fail_fast: bool = False,  # abort the whole task group on the first error
+) -> Dict[str, Any]:
+    """using asyncio.TaskGroup to process I/O-intensive tasks"""
+    if not task_list:
+        return {"total_task": 0, "processed_task": 0, "errors": []}
+
+    processed_task = 0
+    total_task = len(task_list)
+    errors: List[tuple[int, Any, Exception]] = []
+    semaphore = asyncio.Semaphore(max_concurrency)
+    processing_bar = tqdm(total=total_task, unit=unit, desc=description)
+
+    async def _run_single_task(task_obj: Any, idx: int):
+        nonlocal processed_task
+        async with semaphore:
+            try:
+                await handler(task_obj)
+                processed_task += 1
+            except Exception as e:
+                if fail_fast:
+                    raise e
+                errors.append((idx, task_obj, e))
+            finally:
+                processing_bar.update()
+
+    async with asyncio.TaskGroup() as task_group:
+        for index, task in enumerate(task_list, start=1):
+            task_group.create_task(
+                _run_single_task(task, index), name=f"processing {description}-{index}"
+            )
+
+    processing_bar.close()
+    return {
+        "total_task": total_task,
+        "processed_task": processed_task,
+        "errors": errors,
+    }
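A self-contained usage sketch of `run_tasks_with_asyncio_task_group` as defined above; the handler and items are made up for illustration:

```python
import asyncio

from applications.utils.async_utils import run_tasks_with_asyncio_task_group

async def fake_handler(item: int) -> None:
    # stand-in for real I/O work, e.g. an embedding or DB call
    await asyncio.sleep(0.1)
    if item == 3:
        raise ValueError("boom")

async def main() -> None:
    result = await run_tasks_with_asyncio_task_group(
        task_list=list(range(10)),
        handler=fake_handler,
        description="demo batch",
        unit="item",
        max_concurrency=4,
    )
    # with fail_fast=False (the default), errors are collected rather than raised
    print(result["processed_task"], "succeeded,", len(result["errors"]), "failed")

asyncio.run(main())
```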

+ 1 - 1
applications/utils/milvus/field.py

@@ -7,7 +7,7 @@ fields = [
         dtype=DataType.INT64,
         is_primary=True,
         auto_id=True,
-        description="自增逐渐id",
+        description="自增id",
     ),
     FieldSchema(
         name="doc_id", dtype=DataType.VARCHAR, max_length=64, description="文档id"

+ 13 - 5
applications/utils/mysql/mapper.py

@@ -72,7 +72,7 @@ class ContentChunks(BaseMySQLClient):
             WHERE doc_id = %s AND chunk_id = %s AND embedding_status = %s;
         """
         return await self.pool.async_save(
-            query=query, params=(new_status, doc_id, chunk_id,  ori_status)
+            query=query, params=(new_status, doc_id, chunk_id, ori_status)
         )
 
     async def set_chunk_result(self, chunk: Chunk, ori_status, new_status):
@@ -84,8 +84,16 @@ class ContentChunks(BaseMySQLClient):
         return await self.pool.async_save(
             query=query,
             params=(
-                chunk.summary, chunk.topic, chunk.domain, chunk.task_type,
-                json.dumps(chunk.concepts), json.dumps(chunk.keywords), json.dumps(chunk.questions), new_status,
-                chunk.doc_id, chunk.chunk_id, ori_status
-            )
+                chunk.summary,
+                chunk.topic,
+                chunk.domain,
+                chunk.task_type,
+                json.dumps(chunk.concepts),
+                json.dumps(chunk.keywords),
+                json.dumps(chunk.questions),
+                new_status,
+                chunk.doc_id,
+                chunk.chunk_id,
+                ori_status,
+            ),
         )

+ 2 - 2
routes/buleprint.py

@@ -2,7 +2,7 @@ from quart import Blueprint, jsonify, request
 
 from applications.config import DEFAULT_MODEL, LOCAL_MODEL_CONFIG, ChunkerConfig
 from applications.api import get_basic_embedding
-from applications.async_task import ChunkTask
+from applications.async_task import ChunkEmbeddingTask
 
 server_bp = Blueprint("api", __name__, url_prefix="/api")
 
@@ -27,7 +27,7 @@ def server_routes(mysql_db, vector_db):
         if not text:
             return jsonify({"error": "error  text"})
 
-        chunk_task = ChunkTask(mysql_db, vector_db, cfg=ChunkerConfig())
+        chunk_task = ChunkEmbeddingTask(mysql_db, vector_db, cfg=ChunkerConfig())
         doc_id = await chunk_task.deal(body)
         return jsonify({"doc_id": doc_id})