
Add MCP service

xueyiming 2 weeks ago
parent commit
399ab94483
5 changed files with 188 additions and 2 deletions
  1. +1 -1    Dockerfile
  2. +0 -0    mcp_server/__init__.py
  3. +129 -0  mcp_server/server.py
  4. +2 -1    requirements.txt
  5. +56 -0   vector_app.py

+ 1 - 1
Dockerfile

@@ -27,7 +27,7 @@ RUN pip install --no-cache-dir -r requirements.txt -i https://mirrors.aliyun.com
 COPY . .
 
 # Expose ports
-EXPOSE 8001
+EXPOSE 8001 8002
 
 # Startup command
 CMD ["hypercorn", "vector_app:app", "--config", "config.toml"]

+ 0 - 0
mcp_server/__init__.py


+ 129 - 0
mcp_server/server.py

@@ -0,0 +1,129 @@
+import json
+import logging
+from typing import Any, Dict, List
+
+import mcp.types as types
+from mcp.server.lowlevel import Server
+
+from applications.config import ES_HOSTS, ELASTIC_SEARCH_INDEX, ES_PASSWORD, MILVUS_CONFIG
+from applications.resource import get_resource_manager, init_resource_manager
+from applications.utils.chat import ChatClassifier
+from applications.utils.mysql import ContentChunks, Contents, ChatResult
+from routes.buleprint import query_search
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+# Initialize the resource manager
+resource_manager = init_resource_manager(
+    es_hosts=ES_HOSTS,
+    es_index=ELASTIC_SEARCH_INDEX,
+    es_password=ES_PASSWORD,
+    milvus_config=MILVUS_CONFIG,
+)
+
+
+def create_mcp_server() -> Server:
+    """创建并配置MCP服务器"""
+    app = Server("mcp-rag-server")
+
+    @app.call_tool()
+    async def call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent]:
+        """处理工具调用"""
+        # ctx = app.request_context
+        if name == "chat-detail":
+            data = await chat_detail(arguments["query_text"])
+            result = json.dumps(data, ensure_ascii=False, indent=2)
+        else:
+            raise ValueError(f"Unknown tool: {name}")
+        return [types.TextContent(type="text", text=result)]
+
+    @app.list_tools()
+    async def list_tools() -> List[types.Tool]:
+        return [
+            types.Tool(
+                name="chat-detail",
+                title="RAG Search",
+                description="Search content and generate a summary",
+                inputSchema={
+                    "type": "object",
+                    "properties": {
+                        "query_text": {
+                            "type": "string",
+                            "description": "用户输入的查询文本",
+                        }
+                    },
+                    "required": ["query_text"],  # 只强制 query_text 必填
+                    "additionalProperties": False,
+                },
+            ),
+        ]
+
+    return app
+
+
+async def chat_detail(query_text: str) -> Dict[str, Any]:
+    dataset_id_strs = "11,12"
+    dataset_ids = dataset_id_strs.split(",")
+    search_type = "hybrid"
+
+    query_results = await query_search(
+        query_text=query_text,
+        filters={"dataset_id": dataset_ids},
+        search_type=search_type,
+    )
+
+    resource = get_resource_manager()
+    content_chunk_mapper = ContentChunks(resource.mysql_client)
+    contents_mapper = Contents(resource.mysql_client)
+    chat_result_mapper = ChatResult(resource.mysql_client)
+
+    res = []
+    for result in query_results["results"]:
+        content_chunks = await content_chunk_mapper.select_chunk_content(
+            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
+        )
+        contents = await contents_mapper.select_content_by_doc_id(result["doc_id"])
+        if not content_chunks:
+            return {"status_code": 500, "detail": "content_chunk not found", "data": {}}
+        if not contents:
+            return {"status_code": 500, "detail": "contents not found", "data": {}}
+
+        content_chunk = content_chunks[0]
+        content = contents[0]
+        res.append(
+            {
+                "contentChunk": content_chunk["text"],
+                "contentSummary": content_chunk["summary"],
+                "content": content["text"],
+                "score": result["score"],
+            }
+        )
+
+    chat_classifier = ChatClassifier()
+    chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
+
+    data = {
+        "result": chat_res["summary"],
+        "status": chat_res["status"],
+        "metaData": res,
+    }
+
+    await chat_result_mapper.insert_chat_result(
+        query_text,
+        dataset_id_strs,
+        json.dumps(data, ensure_ascii=False),
+        chat_res["summary"],
+        chat_res["relevance_score"],
+        chat_res["status"],
+    )
+
+    return data
+
+
+
+
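
A minimal sketch of exercising the new chat-detail tool from an MCP client, assuming the Streamable HTTP endpoint set up in vector_app.py below is reachable at http://localhost:8002/mcp (host, port, and the sample query are placeholders, not part of this commit):

import asyncio

from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client


async def main() -> None:
    # Connect to the Streamable HTTP endpoint mounted at /mcp
    async with streamablehttp_client("http://localhost:8002/mcp") as (read, write, _):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # Lists the single "chat-detail" tool registered in mcp_server/server.py
            tools = await session.list_tools()
            print([tool.name for tool in tools.tools])
            # Calls the tool; the server replies with the RAG summary serialized as JSON text
            result = await session.call_tool("chat-detail", {"query_text": "example query"})
            print(result.content[0].text)


if __name__ == "__main__":
    asyncio.run(main())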

+ 2 - 1
requirements.txt

@@ -20,4 +20,5 @@ quart-cors==0.8.0
 tiktoken==0.11.0
 uvloop==0.21.0
 elasticsearch==8.17.2
-scikit-learn==1.7.2
+scikit-learn==1.7.2
+mcp==1.14.1

+ 56 - 0
vector_app.py

@@ -1,10 +1,21 @@
+import asyncio
+import contextlib
+from typing import AsyncIterator
+
+import anyio
 import jieba
+from uvicorn import Config, Server
+from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
 from quart import Quart
+from starlette.applications import Starlette
+from starlette.routing import Mount
+from starlette.types import Receive, Send, Scope
 
 from applications.config import LOCAL_MODEL_CONFIG, DEFAULT_MODEL
 from applications.config import ES_HOSTS, ES_PASSWORD, ELASTIC_SEARCH_INDEX
 from applications.config import MILVUS_CONFIG
 from applications.resource import init_resource_manager
+from mcp_server.server import create_mcp_server, logger
 
 app = Quart(__name__)
 
@@ -19,12 +30,57 @@ resource_manager = init_resource_manager(
 )
 
 
+# Make sure we do not spin up a duplicate event loop
+async def start_mcp_server(host: str, port: int, json_response: bool):
+    try:
+        # Create the MCP application
+        app = create_mcp_server()
+        session_manager = StreamableHTTPSessionManager(
+            app=app,
+            event_store=None,
+            json_response=json_response,
+            stateless=True,
+        )
+        logger.info("Session Manager created and ready.")
+
+        async def handle_streamable_http(scope: Scope, receive: Receive, send: Send) -> None:
+            try:
+                logger.info(f"Started processing request: {scope}")
+                await session_manager.handle_request(scope, receive, send)
+                logger.info(f"Finished processing request: {scope}")
+            except anyio.ClosedResourceError:
+                logger.error("Stream closed unexpectedly during request.")
+            except Exception as e:
+                logger.error(f"Unexpected error: {e}")
+
+        # Application lifespan management
+        @contextlib.asynccontextmanager
+        async def lifespan(starlette_app: Starlette) -> AsyncIterator[None]:
+            async with session_manager.run():
+                yield
+
+        # Configure the Starlette application
+        starlette_app = Starlette(
+            debug=True,
+            routes=[Mount("/mcp", app=handle_streamable_http)],
+            lifespan=lifespan,
+        )
+        config = Config(app=starlette_app, host=host, port=port)
+        server = Server(config=config)
+        await server.serve()
+
+    except Exception as e:
+        logger.error(f"Error in start_mcp_server: {e}")
+        raise
+
 @app.before_serving
 async def startup():
     await resource_manager.startup()
     print("Resource manager is ready.")
     jieba.initialize()
     print("Jieba dictionary loaded successfully")
+    # Start mcp_server asynchronously with asyncio.create_task
+    asyncio.create_task(start_mcp_server("0.0.0.0", 8002, json_response=True))
 
 
 @app.after_serving