浏览代码

增加问题分解

xueyiming 8 小时之前
父节点
当前提交
36eed94dfe
共有 4 个文件被更改,包括 177 次插入和 80 次删除
  1. 31 0
      applications/utils/chat/rag_chat_agent.py
  2. 1 1
      applications/utils/mysql/mapper.py
  3. 71 38
      mcp_server/server.py
  4. 74 41
      routes/buleprint.py

+ 31 - 0
applications/utils/chat/rag_chat_agent.py

@@ -141,3 +141,34 @@ class RAGChatAgent:
             model="DeepSeek-R1", prompt=prompt, output_type="json"
         )
         return response
+
+    @staticmethod
+    def split_query_prompt(query):
+        prompt = f"""
+        请将以下问题拆解成最多3个宽泛的子问题。要求:
+
+        1. 子问题应该围绕原始问题本身展开,但需要更宽泛,避免过于细化。
+        2. 每个子问题应关注原始问题的核心内容,但从不同的角度、层面或维度去思考。
+        3. 子问题要求和原问题同属同一类别
+
+        原始问题:
+        {query}
+
+        请按照以下JSON格式返回结果:
+        {{
+            "original_question": "原始问题",
+            "split_questions": [
+                "第一个宽泛问题", "第二个宽泛问题", "第三个宽泛问题"
+            ]
+        }}
+
+        请确保返回的内容是纯JSON格式,不要包含其他任何文字。
+        """
+        return prompt
+
+    async def split_query(self, query):
+        prompt = self.split_query_prompt(query)
+        response = await fetch_deepseek_completion(
+            model="DeepSeek-V3", prompt=prompt, output_type="json"
+        )
+        return response

+ 1 - 1
applications/utils/mysql/mapper.py

@@ -48,7 +48,7 @@ class ChatResult(BaseMySQLClient):
         ai_source,
         ai_status,
         final_result,
-        study_task_id,
+        study_task_id=None,
         is_web=None,
     ):
         query = """

+ 71 - 38
mcp_server/server.py

@@ -53,43 +53,76 @@ def create_mcp_server() -> Server:
     return app
 
 
async def process_question(question, query_text, rag_chat_agent):
    """Run the full RAG pipeline for one (sub-)question and persist the trace.

    Args:
        question: the (sub-)question actually searched and answered.
        query_text: the user's original query; kept only for interface
            compatibility — it is not used in the body.
        rag_chat_agent: shared RAGChatAgent instance.

    Returns:
        dict with query/result/status/relevance_score on success, or
        {"query": ..., "error": ...} when any step fails.
    """
    try:
        dataset_id_strs = "11,12"
        dataset_ids = dataset_id_strs.split(",")
        search_type = "hybrid"

        # Retrieve candidate chunks for this question.
        query_results = await query_search(
            query_text=question,
            filters={"dataset_id": dataset_ids},
            search_type=search_type,
        )

        resource = get_resource_manager()
        chat_result_mapper = ChatResult(resource.mysql_client)

        # The RAG-grounded answer and the plain LLM answer are independent
        # of each other, so run them concurrently instead of sequentially.
        chat_result, llm_search_result = await asyncio.gather(
            rag_chat_agent.chat_with_deepseek(question, query_results),
            rag_chat_agent.llm_search(question),
        )

        # Kick off a study task only when the RAG answer was unusable.
        study_task_id = None
        if chat_result["status"] == 0:
            study_task_id = study(question)["task_id"]

        # Decide which answer to surface.
        decision = await rag_chat_agent.make_decision(chat_result, llm_search_result)

        data = {
            "query": question,
            "result": decision["result"],
            "status": decision["status"],
            "relevance_score": decision["relevance_score"],
        }

        # Persist the full pipeline trace for this question.
        await chat_result_mapper.insert_chat_result(
            question,
            dataset_id_strs,
            json.dumps(query_results, ensure_ascii=False),
            chat_result["summary"],
            chat_result["relevance_score"],
            chat_result["status"],
            llm_search_result["answer"],
            llm_search_result["source"],
            llm_search_result["status"],
            decision["result"],
            study_task_id,
        )
        return data
    except Exception as e:
        # Boundary handler: one failing sub-question must not sink the
        # whole asyncio.gather batch in the caller.
        # NOTE(review): prefer logger.exception(...) over print once a logger exists.
        print(f"Error processing question: {question}. Error: {str(e)}")
        return {"query": question, "error": str(e)}
+
+
 async def rag_search(query_text: str):
-    dataset_id_strs = "11,12"
-    dataset_ids = dataset_id_strs.split(",")
-    search_type = "hybrid"
-
-    query_results = await query_search(
-        query_text=query_text,
-        filters={"dataset_id": dataset_ids},
-        search_type=search_type,
-    )
-
-    resource = get_resource_manager()
-    chat_result_mapper = ChatResult(resource.mysql_client)
     rag_chat_agent = RAGChatAgent()
-    chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
-    study_task_id = None
-    if chat_result["status"] == 0:
-        study_task_id = study(query_text)['task_id']
-    llm_search_result = await rag_chat_agent.llm_search(query_text)
-    decision = await rag_chat_agent.make_decision(chat_result, llm_search_result)
-    data = {
-        "result": decision["result"],
-        "status": decision["status"],
-        "relevance_score": decision["relevance_score"],
-    }
-    await chat_result_mapper.insert_chat_result(
-        query_text,
-        dataset_id_strs,
-        json.dumps(query_results, ensure_ascii=False),
-        chat_result["summary"],
-        chat_result["relevance_score"],
-        chat_result["status"],
-        llm_search_result["answer"],
-        llm_search_result["source"],
-        llm_search_result["status"],
-        decision["result"],
-        study_task_id
-    )
-
-    return data
+    spilt_query = await rag_chat_agent.split_query(query_text)
+    split_questions = spilt_query["split_questions"]
+    split_questions.append(query_text)
+
+    # 使用asyncio.gather并行处理每个问题
+    tasks = [
+        process_question(question, query_text, rag_chat_agent)
+        for question in split_questions
+    ]
+
+    # 等待所有任务完成并收集结果
+    data_list = await asyncio.gather(*tasks)
+    return data_list

+ 74 - 41
routes/buleprint.py

@@ -414,9 +414,9 @@ async def chat():
 
     rag_chat_agent = RAGChatAgent()
     chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
-    study_task_id = None
-    if chat_result["status"] == 0:
-        study_task_id = study(query_text)['task_id']
+    # study_task_id = None
+    # if chat_result["status"] == 0:
+    #     study_task_id = study(query_text)['task_id']
     llm_search = await rag_chat_agent.llm_search(query_text)
     decision = await rag_chat_agent.make_decision(chat_result, llm_search)
     data = {
@@ -436,7 +436,6 @@ async def chat():
         llm_search["source"],
         llm_search["status"],
         decision["result"],
-        study_task_id,
         is_web=1,
     )
     return jsonify({"status_code": 200, "detail": "success", "data": data})
@@ -518,44 +517,78 @@ async def delete_task():
 async def rag_search():
     body = await request.get_json()
     query_text = body.get("queryText")
-    dataset_id_strs = "11,12"
-    dataset_ids = dataset_id_strs.split(",")
-    search_type = "hybrid"
-
-    query_results = await query_search(
-        query_text=query_text,
-        filters={"dataset_id": dataset_ids},
-        search_type=search_type,
-        limit=5,
-    )
-    resource = get_resource_manager()
-    chat_result_mapper = ChatResult(resource.mysql_client)
     rag_chat_agent = RAGChatAgent()
-    chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
-    study_task_id = None
-    if chat_result["status"] == 0:
-        study_task_id = study(query_text)['task_id']
-    llm_search = await rag_chat_agent.llm_search(query_text)
-    decision = await rag_chat_agent.make_decision(chat_result, llm_search)
-    data = {
-        "result": decision["result"],
-        "status": decision["status"],
-        "relevance_score": decision["relevance_score"],
-    }
-    await chat_result_mapper.insert_chat_result(
-        query_text,
-        dataset_id_strs,
-        json.dumps(query_results, ensure_ascii=False),
-        chat_result["summary"],
-        chat_result["relevance_score"],
-        chat_result["status"],
-        llm_search["answer"],
-        llm_search["source"],
-        llm_search["status"],
-        decision["result"],
-        study_task_id
-    )
-    return jsonify({"status_code": 200, "detail": "success", "data": data})
+    spilt_query = await rag_chat_agent.split_query(query_text)
+    split_questions = spilt_query["split_questions"]
+    split_questions.append(query_text)
+
+    # 使用asyncio.gather并行处理每个问题
+    tasks = [
+        process_question(question, query_text, rag_chat_agent)
+        for question in split_questions
+    ]
+
+    # 等待所有任务完成并收集结果
+    data_list = await asyncio.gather(*tasks)
+    return jsonify({"status_code": 200, "detail": "success", "data": data_list})
+
+
async def process_question(question, query_text, rag_chat_agent):
    """Run the full RAG pipeline for one (sub-)question and persist the trace.

    Args:
        question: the (sub-)question actually searched and answered.
        query_text: the user's original query; kept only for interface
            compatibility — it is not used in the body.
        rag_chat_agent: shared RAGChatAgent instance.

    Returns:
        dict with query/result/status/relevance_score on success, or
        {"query": ..., "error": ...} when any step fails.
    """
    try:
        dataset_id_strs = "11,12"
        dataset_ids = dataset_id_strs.split(",")
        search_type = "hybrid"

        # Retrieve candidate chunks for this question.
        query_results = await query_search(
            query_text=question,
            filters={"dataset_id": dataset_ids},
            search_type=search_type,
        )

        resource = get_resource_manager()
        chat_result_mapper = ChatResult(resource.mysql_client)

        # The RAG-grounded answer and the plain LLM answer are independent
        # of each other, so run them concurrently instead of sequentially.
        chat_result, llm_search_result = await asyncio.gather(
            rag_chat_agent.chat_with_deepseek(question, query_results),
            rag_chat_agent.llm_search(question),
        )

        # Kick off a study task only when the RAG answer was unusable.
        study_task_id = None
        if chat_result["status"] == 0:
            study_task_id = study(question)["task_id"]

        # Decide which answer to surface.
        decision = await rag_chat_agent.make_decision(chat_result, llm_search_result)

        data = {
            "query": question,
            "result": decision["result"],
            "status": decision["status"],
            "relevance_score": decision["relevance_score"],
        }

        # Persist the full pipeline trace for this question.
        await chat_result_mapper.insert_chat_result(
            question,
            dataset_id_strs,
            json.dumps(query_results, ensure_ascii=False),
            chat_result["summary"],
            chat_result["relevance_score"],
            chat_result["status"],
            llm_search_result["answer"],
            llm_search_result["source"],
            llm_search_result["status"],
            decision["result"],
            study_task_id,
        )
        return data
    except Exception as e:
        # Boundary handler: one failing sub-question must not sink the
        # whole asyncio.gather batch in the caller.
        # NOTE(review): prefer logger.exception(...) over print once a logger exists.
        print(f"Error processing question: {question}. Error: {str(e)}")
        return {"query": question, "error": str(e)}
 
 
 @server_bp.route("/chat/history", methods=["GET"])