server.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. import json
  2. import logging
  3. from typing import Any, Dict, List
  4. import mcp.types as types
  5. from mcp.server.lowlevel import Server
  6. from applications.config import ES_HOSTS, ELASTIC_SEARCH_INDEX, ES_PASSWORD, MILVUS_CONFIG
  7. from applications.resource import get_resource_manager, init_resource_manager
  8. from applications.utils.chat import ChatClassifier
  9. from applications.utils.mysql import ContentChunks, Contents, ChatResult
  10. from routes.buleprint import query_search
  # Configure root logging at import time (logging.basicConfig is a no-op if the
  # root logger already has handlers), then obtain this module's logger.
  _LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
  logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
  logger = logging.getLogger(__name__)
  # Initialise the resource manager at import time from applications.config values
  # (Elasticsearch hosts/index/password and the Milvus config).
  # NOTE(review): presumably this is the same instance later returned by
  # get_resource_manager() in rag_search — confirm.
  resource_manager = init_resource_manager(
      milvus_config=MILVUS_CONFIG,
      es_password=ES_PASSWORD,
      es_index=ELASTIC_SEARCH_INDEX,
      es_hosts=ES_HOSTS,
  )
  def create_mcp_server() -> Server:
      """Create and configure the MCP server exposing the ``rag-search`` tool.

      Returns:
          A low-level ``Server`` named ``"mcp-rag-server"`` with ``call_tool``
          and ``list_tools`` handlers registered on it.
      """
      app = Server("mcp-rag-server")
      # Single source of truth for the tool name, so the dispatcher below and
      # the advertised tool list cannot drift apart.
      tool_name = "rag-search"

      @app.call_tool()
      async def call_tool(name: str, arguments: Dict[str, Any]) -> List[types.TextContent]:
          """Dispatch a tool call and return its result as pretty-printed JSON text.

          Raises:
              ValueError: if ``name`` is not a known tool, or if ``query_text``
                  is missing, not a string, or blank.
          """
          if name != tool_name:
              raise ValueError(f"Unknown tool: {name}")
          # Tolerate a missing/None payload and report bad input explicitly
          # instead of surfacing a bare KeyError/TypeError to the client, and
          # avoid running search + LLM + DB insert for an empty query.
          query_text = (arguments or {}).get("query_text")
          if not isinstance(query_text, str) or not query_text.strip():
              raise ValueError("query_text must be a non-empty string")
          data = await rag_search(query_text)
          result = json.dumps(data, ensure_ascii=False, indent=2)
          return [types.TextContent(type="text", text=result)]

      @app.list_tools()
      async def list_tools() -> List[types.Tool]:
          """Advertise the single ``rag-search`` tool and its input schema."""
          return [
              types.Tool(
                  name=tool_name,
                  title='RAG搜索',
                  description="搜索内容并生成总结",
                  inputSchema={
                      "type": "object",
                      "properties": {
                          "query_text": {
                              "type": "string",
                              "description": "用户输入的查询文本",
                          }
                      },
                      "required": ["query_text"],  # only query_text is mandatory
                      "additionalProperties": False,
                  },
              ),
          ]

      return app
  async def rag_search(query_text: str, dataset_id_strs: str = "11,12") -> Dict[str, Any]:
      """Search, enrich hits from MySQL, summarise with DeepSeek and persist the result.

      Args:
          query_text: The user's query text.
          dataset_id_strs: Comma-separated dataset ids used as the ``dataset_id``
              search filter. Defaults to the previously hard-coded ``"11,12"``;
              this string is also stored with the chat result.

      Returns:
          On success, ``{"result", "status", "score"}`` taken from the
          classifier response. If any hit's chunk or parent content row is
          missing, returns ``{"status_code": 500, "detail": ..., "data": {}}``
          immediately, without calling the LLM or persisting anything.
      """
      # Strip whitespace so "11, 12" also works; for the default value this is
      # identical to a plain split(",").
      dataset_ids = [part.strip() for part in dataset_id_strs.split(",") if part.strip()]
      query_results = await query_search(
          query_text=query_text,
          filters={"dataset_id": dataset_ids},
          search_type="hybrid",
      )

      resource = get_resource_manager()
      content_chunk_mapper = ContentChunks(resource.mysql_client)
      contents_mapper = Contents(resource.mysql_client)
      chat_result_mapper = ChatResult(resource.mysql_client)

      res = []
      for result in query_results["results"]:
          doc_id = result["doc_id"]
          chunk_id = result["chunk_id"]
          content_chunks = await content_chunk_mapper.select_chunk_content(
              doc_id=doc_id, chunk_id=chunk_id
          )
          # Check the chunk before querying its parent document: a missing chunk
          # already aborts the request, so the second query would be wasted.
          if not content_chunks:
              logger.warning("content_chunk not found: doc_id=%s chunk_id=%s", doc_id, chunk_id)
              return {"status_code": 500, "detail": "content_chunk not found", "data": {}}
          contents = await contents_mapper.select_content_by_doc_id(doc_id)
          if not contents:
              logger.warning("contents not found: doc_id=%s", doc_id)
              return {"status_code": 500, "detail": "contents not found", "data": {}}
          content_chunk = content_chunks[0]
          content = contents[0]
          res.append(
              {
                  "contentChunk": content_chunk["text"],
                  "contentSummary": content_chunk["summary"],
                  "content": content["text"],
                  "score": result["score"],
              }
          )

      chat_classifier = ChatClassifier()
      chat_res = await chat_classifier.chat_with_deepseek(query_text, res)
      data = {
          "result": chat_res["summary"],
          "status": chat_res["status"],
          "score": chat_res["score"],
      }
      # NOTE(review): the response returns chat_res["score"] but persists
      # chat_res["relevance_score"]; confirm both keys are intended.
      await chat_result_mapper.insert_chat_result(
          query_text,
          dataset_id_strs,
          json.dumps(res, ensure_ascii=False),
          chat_res["summary"],
          chat_res["relevance_score"],
          chat_res["status"],
      )
      return data