# server.py
import json
from typing import Any, Dict, List

import mcp.types as types
from mcp.server.lowlevel import Server

from applications.resource import get_resource_manager
from applications.utils.chat import ChatClassifier
from applications.utils.mysql import ChatResult, ContentChunks, Contents
from routes.buleprint import query_search
  9. def create_mcp_server() -> Server:
  10. """创建并配置MCP服务器"""
  11. app = Server("mcp-rag-server")
  12. @app.call_tool()
  13. async def call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent]:
  14. """处理工具调用"""
  15. # ctx = app.request_context
  16. if name == "rag-search":
  17. data = await rag_search(arguments["query_text"])
  18. result = json.dumps(data, ensure_ascii=False, indent=2)
  19. else:
  20. raise ValueError(f"Unknown tool: {name}")
  21. return [types.TextContent(type="text", text=result)]
  22. @app.list_tools()
  23. async def list_tools() -> List[types.Tool]:
  24. return [
  25. types.Tool(
  26. name="rag-search",
  27. title = 'RAG搜索',
  28. description="搜索内容并生成总结",
  29. inputSchema={
  30. "type": "object",
  31. "properties": {
  32. "query_text": {
  33. "type": "string",
  34. "description": "用户输入的查询文本",
  35. }
  36. },
  37. "required": ["query_text"], # 只强制 query_text 必填
  38. "additionalProperties": False,
  39. },
  40. ),
  41. ]
  42. return app
  43. async def rag_search(query_text: str) :
  44. dataset_id_strs = "11,12"
  45. dataset_ids = dataset_id_strs.split(",")
  46. search_type = "hybrid"
  47. query_results = await query_search(
  48. query_text=query_text,
  49. filters={"dataset_id": dataset_ids},
  50. search_type=search_type,
  51. )
  52. resource = get_resource_manager()
  53. content_chunk_mapper = ContentChunks(resource.mysql_client)
  54. contents_mapper = Contents(resource.mysql_client)
  55. chat_result_mapper = ChatResult(resource.mysql_client)
  56. res = []
  57. for result in query_results["results"]:
  58. content_chunks = await content_chunk_mapper.select_chunk_content(
  59. doc_id=result["doc_id"], chunk_id=result["chunk_id"]
  60. )
  61. contents = await contents_mapper.select_content_by_doc_id(result["doc_id"])
  62. if not content_chunks:
  63. return {"status_code": 500, "detail": "content_chunk not found", "data": {}}
  64. if not contents:
  65. return {"status_code": 500, "detail": "contents not found", "data": {}}
  66. content_chunk = content_chunks[0]
  67. content = contents[0]
  68. res.append(
  69. {
  70. "contentChunk": content_chunk["text"],
  71. "contentSummary": content_chunk["summary"],
  72. "content": content["text"],
  73. "score": result["score"],
  74. }
  75. )
  76. chat_classifier = ChatClassifier()
  77. chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
  78. data = {
  79. "result": chat_res["summary"],
  80. "status": chat_res["status"],
  81. "relevance_score": chat_res["relevance_score"],
  82. }
  83. await chat_result_mapper.insert_chat_result(
  84. query_text,
  85. dataset_id_strs,
  86. json.dumps(res, ensure_ascii=False),
  87. chat_res["summary"],
  88. chat_res["relevance_score"],
  89. chat_res["status"],
  90. )
  91. return data