Pārlūkot izejas kodu

新增 dont_chunk模块

luojunhui 1 nedēļu atpakaļ
vecāks
revīzija
d18cca7a58

+ 2 - 9
applications/async_task/chunk_task.py

@@ -29,11 +29,7 @@ class ChunkEmbeddingTask(TopicAwarePackerV2):
         self.content_manager = Contents(self.mysql_client)
         self.chunk_manager = ContentChunks(self.mysql_client)
 
-    async def _chunk_each_content(
-        self,
-        doc_id: str,
-        data: dict
-    ) -> List[Chunk]:
+    async def _chunk_each_content(self, doc_id: str, data: dict) -> List[Chunk]:
         title, text = data.get("title", "").strip(), data["text"].strip()
         text_type = data.get("text_type", 1)
         dataset_id = data.get("dataset_id", 0)  # 默认知识库 id 为 0
@@ -220,14 +216,11 @@ class ChunkEmbeddingTask(TopicAwarePackerV2):
         dont_chunk = data.get("dont_chunk", False)
         # 如果无需分块,判断text 长度
         if dont_chunk and num_tokens(text) >= self.max_tokens:
-            return {
-                "error": "文档超多模型支持的最大吞吐量"
-            }
+            return {"error": "文档超多模型支持的最大吞吐量"}
 
         self.init_processer()
 
         async def _process():
-
             chunks = await self._chunk_each_content(self.doc_id, data)
             if not chunks:
                 return

+ 3 - 1
applications/utils/chunks/topic_aware_chunking.py

@@ -161,7 +161,9 @@ class TopicAwarePackerV2(TopicAwareChunker):
 
         return chunks
 
-    async def chunk(self, text: str, text_type: int, dataset_id: int, dont_chunk: bool) -> List[Chunk]:
+    async def chunk(
+        self, text: str, text_type: int, dataset_id: int, dont_chunk: bool
+    ) -> List[Chunk]:
         raw_info = await self._raw_chunk(text, dont_chunk)
         if not raw_info:
             return []

+ 10 - 2
mcp_app.py

@@ -8,7 +8,12 @@ from starlette.applications import Starlette
 from starlette.routing import Mount
 from starlette.types import Receive, Scope, Send
 
-from applications.config import ES_HOSTS, ELASTIC_SEARCH_INDEX, ES_PASSWORD, MILVUS_CONFIG
+from applications.config import (
+    ES_HOSTS,
+    ELASTIC_SEARCH_INDEX,
+    ES_PASSWORD,
+    MILVUS_CONFIG,
+)
 from applications.resource import init_resource_manager
 from mcp_server.server import create_mcp_server
 
@@ -47,7 +52,9 @@ def main(port: int, host: str, json_response: bool) -> int:
     )
 
     # 处理Streamable HTTP请求
-    async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
+    async def handle_streamable_http(
+        scope: Scope, receive: Receive, send: Send
+    ) -> None:
         await session_manager.handle_request(scope, receive, send)
 
     # 定义生命周期管理
@@ -66,6 +73,7 @@ def main(port: int, host: str, json_response: bool) -> int:
 
     # 启动服务器
     import uvicorn
+
     uvicorn.run(starlette_app, host=host, port=port)
     return 0
 

+ 5 - 7
mcp_server/server.py

@@ -15,7 +15,9 @@ def create_mcp_server() -> Server:
     app = Server("mcp-rag-server")
 
     @app.call_tool()
-    async def call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent]:
+    async def call_tool(
+        name: str, arguments: Dict[str, Any]
+    ) -> List[types.TextContent]:
         """处理工具调用"""
         # ctx = app.request_context
         if name == "rag-search":
@@ -30,7 +32,7 @@ def create_mcp_server() -> Server:
         return [
             types.Tool(
                 name="rag-search",
-                title = 'RAG搜索',
+                title="RAG搜索",
                 description="搜索内容并生成总结",
                 inputSchema={
                     "type": "object",
@@ -49,7 +51,7 @@ def create_mcp_server() -> Server:
     return app
 
 
-async def rag_search(query_text: str) :
+async def rag_search(query_text: str):
     dataset_id_strs = "11,12"
     dataset_ids = dataset_id_strs.split(",")
     search_type = "hybrid"
@@ -106,7 +108,3 @@ async def rag_search(query_text: str) :
     )
 
     return data
-
-
-
-