server.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import json
  2. from typing import Any, Dict, List
  3. import mcp.types as types
  4. from mcp.server.lowlevel import Server
  5. from applications.resource import get_resource_manager
  6. from applications.utils.chat import ChatClassifier
  7. from applications.utils.mysql import ContentChunks, Contents, ChatResult
  8. from routes.buleprint import query_search
  9. import json
  10. from typing import Any, Dict, List
  11. import mcp.types as types
  12. from mcp.server.lowlevel import Server
  13. from applications.resource import get_resource_manager
  14. from applications.utils.chat import ChatClassifier
  15. from applications.utils.mysql import ContentChunks, Contents, ChatResult
  16. from routes.buleprint import query_search
  17. def create_mcp_server() -> Server:
  18. """创建并配置MCP服务器"""
  19. app = Server("mcp-rag-server")
  20. @app.call_tool()
  21. async def call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent]:
  22. """处理工具调用"""
  23. # ctx = app.request_context
  24. if name == "rag-search":
  25. data = await rag_search(arguments["query_text"])
  26. result = json.dumps(data, ensure_ascii=False, indent=2)
  27. else:
  28. raise ValueError(f"Unknown tool: {name}")
  29. return [types.TextContent(type="text", text=result)]
  30. @app.list_tools()
  31. async def list_tools() -> List[types.Tool]:
  32. return [
  33. types.Tool(
  34. name="rag-search",
  35. title = 'RAG搜索',
  36. description="搜索内容并生成总结",
  37. inputSchema={
  38. "type": "object",
  39. "properties": {
  40. "query_text": {
  41. "type": "string",
  42. "description": "用户输入的查询文本",
  43. }
  44. },
  45. "required": ["query_text"], # 只强制 query_text 必填
  46. "additionalProperties": False,
  47. },
  48. ),
  49. ]
  50. return app
  51. async def rag_search(query_text: str) :
  52. dataset_id_strs = "11,12"
  53. dataset_ids = dataset_id_strs.split(",")
  54. search_type = "hybrid"
  55. query_results = await query_search(
  56. query_text=query_text,
  57. filters={"dataset_id": dataset_ids},
  58. search_type=search_type,
  59. )
  60. resource = get_resource_manager()
  61. content_chunk_mapper = ContentChunks(resource.mysql_client)
  62. contents_mapper = Contents(resource.mysql_client)
  63. chat_result_mapper = ChatResult(resource.mysql_client)
  64. res = []
  65. for result in query_results["results"]:
  66. content_chunks = await content_chunk_mapper.select_chunk_content(
  67. doc_id=result["doc_id"], chunk_id=result["chunk_id"]
  68. )
  69. contents = await contents_mapper.select_content_by_doc_id(result["doc_id"])
  70. if not content_chunks:
  71. return {"status_code": 500, "detail": "content_chunk not found", "data": {}}
  72. if not contents:
  73. return {"status_code": 500, "detail": "contents not found", "data": {}}
  74. content_chunk = content_chunks[0]
  75. content = contents[0]
  76. res.append(
  77. {
  78. "contentChunk": content_chunk["text"],
  79. "contentSummary": content_chunk["summary"],
  80. "content": content["text"],
  81. "score": result["score"],
  82. }
  83. )
  84. chat_classifier = ChatClassifier()
  85. chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
  86. data = {
  87. "result": chat_res["summary"],
  88. "status": chat_res["status"],
  89. "relevance_score": chat_res["relevance_score"],
  90. }
  91. await chat_result_mapper.insert_chat_result(
  92. query_text,
  93. dataset_id_strs,
  94. json.dumps(res, ensure_ascii=False),
  95. chat_res["summary"],
  96. chat_res["relevance_score"],
  97. chat_res["status"],
  98. )
  99. return data