# server.py
import asyncio
import json
import logging
from typing import Any, Dict, List

import mcp.types as types
from mcp.server.lowlevel import Server

from applications.api.qwen import QwenClient
from applications.resource import get_resource_manager
from applications.utils.chat import RAGChatAgent
from applications.utils.mysql import ChatResult
from applications.utils.spider.study import study
from applications.utils.task.async_task import query_search
  12. def create_mcp_server() -> Server:
  13. """创建并配置MCP服务器"""
  14. app = Server("mcp-rag-server")
  15. @app.call_tool()
  16. async def call_tool(
  17. name: str, arguments: Dict[str, Any]
  18. ) -> List[types.TextContent]:
  19. """处理工具调用"""
  20. # ctx = app.request_context
  21. if name == "rag-search":
  22. data = await rag_search(arguments["query_text"])
  23. result = json.dumps(data, ensure_ascii=False, indent=2)
  24. else:
  25. raise ValueError(f"Unknown tool: {name}")
  26. return [types.TextContent(type="text", text=result)]
  27. @app.list_tools()
  28. async def list_tools() -> List[types.Tool]:
  29. return [
  30. types.Tool(
  31. name="rag-search",
  32. title="RAG搜索",
  33. description="搜索内容并生成总结",
  34. inputSchema={
  35. "type": "object",
  36. "properties": {
  37. "query_text": {
  38. "type": "string",
  39. "description": "用户输入的查询文本",
  40. }
  41. },
  42. "required": ["query_text"], # 只强制 query_text 必填
  43. "additionalProperties": False,
  44. },
  45. ),
  46. ]
  47. return app
  48. async def process_question(question, query_text, rag_chat_agent):
  49. try:
  50. dataset_id_strs = "11,12"
  51. dataset_ids = dataset_id_strs.split(",")
  52. search_type = "hybrid"
  53. # 执行查询任务
  54. query_results = await query_search(
  55. query_text=question,
  56. filters={"dataset_id": dataset_ids},
  57. search_type=search_type,
  58. )
  59. resource = get_resource_manager()
  60. chat_result_mapper = ChatResult(resource.mysql_client)
  61. # 异步执行 chat 与 deepseek 的对话
  62. chat_result = await rag_chat_agent.chat_with_deepseek(question, query_results)
  63. # # 判断是否需要执行 study
  64. study_task_id = None
  65. if chat_result["status"] == 0:
  66. study_task_id = study(question)["task_id"]
  67. qwen_client = QwenClient()
  68. llm_search = qwen_client.search_and_chat(
  69. user_prompt=question, search_strategy="agent"
  70. )
  71. # 执行决策逻辑
  72. decision = await rag_chat_agent.make_decision(question, chat_result, llm_search)
  73. # 构建返回的数据
  74. data = {
  75. "query": question,
  76. "result": decision["result"],
  77. "status": decision["status"],
  78. "relevance_score": decision["relevance_score"],
  79. }
  80. # 插入数据库
  81. await chat_result_mapper.insert_chat_result(
  82. question,
  83. dataset_id_strs,
  84. json.dumps(query_results, ensure_ascii=False),
  85. chat_result["summary"],
  86. chat_result["relevance_score"],
  87. chat_result["status"],
  88. llm_search["content"],
  89. json.dumps(llm_search["search_results"], ensure_ascii=False),
  90. 1,
  91. decision["result"],
  92. study_task_id,
  93. )
  94. return data
  95. except Exception as e:
  96. print(f"Error processing question: {question}. Error: {str(e)}")
  97. return {"query": question, "error": str(e)}
  98. async def rag_search(query_text: str):
  99. rag_chat_agent = RAGChatAgent()
  100. split_questions = []
  101. # spilt_query = await rag_chat_agent.split_query(query_text)
  102. # split_questions = spilt_query["split_questions"]
  103. split_questions.append(query_text)
  104. # 使用asyncio.gather并行处理每个问题
  105. tasks = [
  106. process_question(question, query_text, rag_chat_agent)
  107. for question in split_questions
  108. ]
  109. # 等待所有任务完成并收集结果
  110. data_list = await asyncio.gather(*tasks)
  111. return data_list