import asyncio
import json
import logging
from typing import Any, Dict, List

import mcp.types as types
from mcp.server.lowlevel import Server

from applications.resource import get_resource_manager
from applications.utils.chat import RAGChatAgent
from applications.utils.mysql import ChatResult
from applications.utils.spider.study import study
from routes.buleprint import query_search
def create_mcp_server() -> Server:
    """Create and configure the MCP RAG server with its tool registry."""
    app = Server("mcp-rag-server")

    @app.call_tool()
    async def call_tool(
        name: str, arguments: Dict[str, Any]
    ) -> List[types.TextContent]:
        """Dispatch an incoming tool invocation to its implementation."""
        # Guard clause: only one tool is registered on this server.
        if name != "rag-search":
            raise ValueError(f"Unknown tool: {name}")
        data = await rag_search(arguments["query_text"])
        result = json.dumps(data, ensure_ascii=False, indent=2)
        return [types.TextContent(type="text", text=result)]

    @app.list_tools()
    async def list_tools() -> List[types.Tool]:
        """Advertise the available tools and their input schemas."""
        rag_tool = types.Tool(
            name="rag-search",
            title="RAG搜索",
            description="搜索内容并生成总结",
            inputSchema={
                "type": "object",
                "properties": {
                    "query_text": {
                        "type": "string",
                        "description": "用户输入的查询文本",
                    }
                },
                # Only query_text is mandatory.
                "required": ["query_text"],
                "additionalProperties": False,
            },
        )
        return [rag_tool]

    return app
async def process_question(question, query_text, rag_chat_agent):
    """Run the full RAG pipeline for a single (sub-)question.

    Retrieves passages from the fixed datasets, asks the chat agent for a
    summary, runs an independent LLM search, merges both via the agent's
    decision step, and persists the full trace to MySQL.

    Args:
        question: The (sub-)question to answer.
        query_text: The original, unsplit user query. Currently unused;
            kept for interface compatibility with existing callers.
        rag_chat_agent: A RAGChatAgent instance shared across questions.

    Returns:
        dict: ``{"query", "result", "status", "relevance_score"}`` on
        success, or ``{"query", "error"}`` if any step failed.
    """
    try:
        # NOTE(review): dataset ids are hard-coded; presumably these are
        # the production knowledge-base datasets — confirm before reuse.
        dataset_id_strs = "11,12"
        dataset_ids = dataset_id_strs.split(",")
        search_type = "hybrid"
        # Retrieve candidate passages from the index.
        query_results = await query_search(
            query_text=question,
            filters={"dataset_id": dataset_ids},
            search_type=search_type,
        )
        resource = get_resource_manager()
        chat_result_mapper = ChatResult(resource.mysql_client)
        # Summarize the retrieved passages with the chat model.
        chat_result = await rag_chat_agent.chat_with_deepseek(question, query_results)
        # status == 0 marks the RAG answer as unusable: kick off a study
        # (crawl) task so the knowledge base can cover this topic later.
        study_task_id = None
        if chat_result["status"] == 0:
            study_task_id = study(question)["task_id"]
        # Independent LLM search as a second answer source.
        llm_search_result = await rag_chat_agent.llm_search(question)
        # Let the agent choose between the RAG summary and the LLM answer.
        decision = await rag_chat_agent.make_decision(chat_result, llm_search_result)
        data = {
            "query": question,
            "result": decision["result"],
            "status": decision["status"],
            "relevance_score": decision["relevance_score"],
        }
        # Persist the full trace (inputs, both answers, decision) for audit.
        await chat_result_mapper.insert_chat_result(
            question,
            dataset_id_strs,
            json.dumps(query_results, ensure_ascii=False),
            chat_result["summary"],
            chat_result["relevance_score"],
            chat_result["status"],
            llm_search_result["answer"],
            llm_search_result["source"],
            llm_search_result["status"],
            decision["result"],
            study_task_id,
        )
        return data
    except Exception as e:
        # Boundary handler: log with traceback instead of print(), but keep
        # the best-effort contract of returning an error payload so one
        # failed sub-question does not abort the whole gather().
        logging.getLogger(__name__).exception(
            "Error processing question: %s", question
        )
        return {"query": question, "error": str(e)}
async def rag_search(query_text: str):
    """Split the query into sub-questions and answer all of them concurrently."""
    rag_chat_agent = RAGChatAgent()
    # Ask the agent to decompose the query; the original query is answered
    # alongside every sub-question it produced.
    decomposition = await rag_chat_agent.split_query(query_text)
    questions = list(decomposition["split_questions"])
    questions.append(query_text)
    # Fan the questions out in parallel and collect their results in order.
    coros = (
        process_question(q, query_text, rag_chat_agent) for q in questions
    )
    return await asyncio.gather(*coros)