server.py

import json
from typing import Any, Dict, List

import mcp.types as types
from mcp.server.lowlevel import Server

from applications.resource import get_resource_manager
from applications.utils.chat import ChatClassifier
from applications.utils.mysql import ContentChunks, Contents, ChatResult
from routes.buleprint import query_search


def create_mcp_server() -> Server:
    """Create and configure the MCP server."""
    app = Server("mcp-rag-server")

    @app.call_tool()
    async def call_tool(
        name: str, arguments: Dict[str, Any]
    ) -> List[types.TextContent]:
        """Handle tool calls."""
        # ctx = app.request_context
        if name == "rag-search":
            data = await rag_search(arguments["query_text"])
            result = json.dumps(data, ensure_ascii=False, indent=2)
        else:
            raise ValueError(f"Unknown tool: {name}")
        return [types.TextContent(type="text", text=result)]
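
    # Illustrative only: a client call such as
    #     call_tool("rag-search", {"query_text": "example question"})
    # returns a single TextContent item whose text is the JSON-encoded dict
    # produced by rag_search(); the query string above is a made-up placeholder.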

    @app.list_tools()
    async def list_tools() -> List[types.Tool]:
        return [
            types.Tool(
                name="rag-search",
                title="RAG Search",
                description="Search content and generate a summary",
                inputSchema={
                    "type": "object",
                    "properties": {
                        "query_text": {
                            "type": "string",
                            "description": "The query text entered by the user",
                        }
                    },
                    "required": ["query_text"],  # only query_text is mandatory
                    "additionalProperties": False,
                },
            ),
        ]

    return app


async def rag_search(query_text: str) -> Dict[str, Any]:
    # Datasets to search: kept as a comma-separated string for persistence
    # and split into a list for filtering.
    dataset_id_strs = "11,12"
    dataset_ids = dataset_id_strs.split(",")
    search_type = "hybrid"

    # Retrieve candidate chunks with the configured hybrid search.
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    )

    resource = get_resource_manager()
    content_chunk_mapper = ContentChunks(resource.mysql_client)
    contents_mapper = Contents(resource.mysql_client)
    chat_result_mapper = ChatResult(resource.mysql_client)

    # Hydrate each hit with its chunk text, chunk summary, and full document text.
    res = []
    for result in query_results["results"]:
        content_chunks = await content_chunk_mapper.select_chunk_content(
            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
        )
        contents = await contents_mapper.select_content_by_doc_id(result["doc_id"])
        if not content_chunks:
            return {"status_code": 500, "detail": "content_chunk not found", "data": {}}
        if not contents:
            return {"status_code": 500, "detail": "contents not found", "data": {}}

        content_chunk = content_chunks[0]
        content = contents[0]
        res.append(
            {
                "contentChunk": content_chunk["text"],
                "contentSummary": content_chunk["summary"],
                "content": content["text"],
                "score": result["score"],
            }
        )

    # Summarize the retrieved context with DeepSeek and persist the chat result.
    chat_classifier = ChatClassifier()
    chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
    data = {
        "result": chat_res["summary"],
        "status": chat_res["status"],
        "relevance_score": chat_res["relevance_score"],
    }
    await chat_result_mapper.insert_chat_result(
        query_text,
        dataset_id_strs,
        json.dumps(res, ensure_ascii=False),
        chat_res["summary"],
        chat_res["relevance_score"],
        chat_res["status"],
    )
    return data
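

# --- Optional entrypoint (sketch) -------------------------------------------
# A minimal way to expose this server over stdio, assuming the standard
# mcp.server.stdio transport from the MCP Python SDK; swap in whatever
# transport (e.g. SSE or streamable HTTP) the rest of this project uses.
if __name__ == "__main__":
    import asyncio

    import mcp.server.stdio

    async def main() -> None:
        app = create_mcp_server()
        async with mcp.server.stdio.stdio_server() as (read_stream, write_stream):
            await app.run(
                read_stream, write_stream, app.create_initialization_options()
            )

    asyncio.run(main())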