
Merge branch 'master' of https://git.yishihui.com/Server/rag_server
merge

luojunhui 2 months ago
parent commit ad3da7c918

+ 41 - 0
applications/utils/chat/rag_chat_agent.py

@@ -1,3 +1,4 @@
+import asyncio
 from typing import List
 
 from applications.config import Chunk
@@ -141,3 +142,43 @@ class RAGChatAgent:
             model="DeepSeek-R1", prompt=prompt, output_type="json"
         )
         return response
+
+    @staticmethod
+    def split_query_prompt(query):
+        prompt = f"""
+        You are an information-retrieval assistant. Your job is to decompose the user's question into sub-questions that are "broader but still of the same kind, and mutually non-overlapping", to be used for recalling diverse evidence.
+
+        [Goals]
+        - Generate 1-3 "broader" sub-questions (broad questions). They should stay in the same category/topic as the original question but extend it from different angles; avoid slicing the original question more finely (avoid over-specific).
+        - The sub-questions should cover different dimensions where possible (e.g. background/principles, impact/applications, comparisons/trends, methods/evaluation), minimizing semantic overlap (≤20% similarity).
+
+        [Must follow]
+        1) Same kind as the original question: if the original is a technical/explainer/comparison/process question, the sub-questions should keep the same register and intent.
+        2) Broader: drop overly fine constraints (specific numbers/versions/names/dates/implementation details) while keeping the topical core.
+        3) Deduplicated and complementary: do not produce near-synonymous rewrites; each sub-question should address a different facet (angle, level, or audience).
+        4) Retrievable: avoid empty abstractions; each sub-question should be a good query usable for search/recall.
+        5) Adaptive count: if the question cannot reasonably be expanded to 3, output 1-2; do not repeat just to pad the count.
+        6) Consistent language: respond in the same language as the original question (Chinese in -> Chinese out; English in -> English out).
+        7) Output JSON only, strictly matching the schema below; do not output any extra text or comments.
+
+        Original question:
+        {query}
+
+        Return the result in the following JSON format:
+        {{
+            "original_question": "the original question",
+            "split_questions": [
+                "first broad question", "second broad question", "third broad question"
+            ]
+        }}
+
+        Make sure the returned content is pure JSON, with no other text.
+        """
+        return prompt
+
+    async def split_query(self, query):
+        prompt = self.split_query_prompt(query)
+        response = await fetch_deepseek_completion(
+            model="DeepSeek-V3", prompt=prompt, output_type="json"
+        )
+        return response
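
The new split_query helper is meant to be awaited and to yield the parsed JSON described by the prompt's schema. A minimal usage sketch, assuming fetch_deepseek_completion(..., output_type="json") returns a parsed dict, which is what the calling code further down in this commit relies on:

    import asyncio

    from applications.utils.chat.rag_chat_agent import RAGChatAgent

    async def demo():
        agent = RAGChatAgent()
        # Expected shape, per the prompt schema:
        # {"original_question": "...", "split_questions": ["...", ...]}
        result = await agent.split_query("How does hybrid search improve RAG recall?")
        for sub_question in result["split_questions"]:
            print(sub_question)

    asyncio.run(demo())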

+ 1 - 1
applications/utils/mysql/mapper.py

@@ -48,7 +48,7 @@ class ChatResult(BaseMySQLClient):
         ai_source,
         ai_status,
         final_result,
-        study_task_id,
+        study_task_id=None,
         is_web=None,
     ):
         query = """

+ 71 - 38
mcp_server/server.py

@@ -53,43 +53,76 @@ def create_mcp_server() -> Server:
     return app
 
 
+async def process_question(question, query_text, rag_chat_agent):
+    try:
+        dataset_id_strs = "11,12"
+        dataset_ids = dataset_id_strs.split(",")
+        search_type = "hybrid"
+
+        # Run the retrieval query for this question
+        query_results = await query_search(
+            query_text=question,
+            filters={"dataset_id": dataset_ids},
+            search_type=search_type,
+        )
+
+        resource = get_resource_manager()
+        chat_result_mapper = ChatResult(resource.mysql_client)
+
+        # Asynchronously run the chat against DeepSeek
+        chat_result = await rag_chat_agent.chat_with_deepseek(question, query_results)
+
+        # Decide whether a study task needs to run
+        study_task_id = None
+        if chat_result["status"] == 0:
+            study_task_id = study(question)["task_id"]
+
+        # Asynchronously fetch the LLM search result
+        llm_search_result = await rag_chat_agent.llm_search(question)
+
+        # Run the decision logic
+        decision = await rag_chat_agent.make_decision(chat_result, llm_search_result)
+
+        # Build the data to return
+        data = {
+            "query": question,
+            "result": decision["result"],
+            "status": decision["status"],
+            "relevance_score": decision["relevance_score"],
+        }
+
+        # Insert the result into the database
+        await chat_result_mapper.insert_chat_result(
+            question,
+            dataset_id_strs,
+            json.dumps(query_results, ensure_ascii=False),
+            chat_result["summary"],
+            chat_result["relevance_score"],
+            chat_result["status"],
+            llm_search_result["answer"],
+            llm_search_result["source"],
+            llm_search_result["status"],
+            decision["result"],
+            study_task_id,
+        )
+        return data
+    except Exception as e:
+        print(f"Error processing question: {question}. Error: {str(e)}")
+        return {"query": question, "error": str(e)}
+
+
 async def rag_search(query_text: str):
-    dataset_id_strs = "11,12"
-    dataset_ids = dataset_id_strs.split(",")
-    search_type = "hybrid"
-
-    query_results = await query_search(
-        query_text=query_text,
-        filters={"dataset_id": dataset_ids},
-        search_type=search_type,
-    )
-
-    resource = get_resource_manager()
-    chat_result_mapper = ChatResult(resource.mysql_client)
     rag_chat_agent = RAGChatAgent()
-    chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
-    study_task_id = None
-    if chat_result["status"] == 0:
-        study_task_id = study(query_text)['task_id']
-    llm_search_result = await rag_chat_agent.llm_search(query_text)
-    decision = await rag_chat_agent.make_decision(chat_result, llm_search_result)
-    data = {
-        "result": decision["result"],
-        "status": decision["status"],
-        "relevance_score": decision["relevance_score"],
-    }
-    await chat_result_mapper.insert_chat_result(
-        query_text,
-        dataset_id_strs,
-        json.dumps(query_results, ensure_ascii=False),
-        chat_result["summary"],
-        chat_result["relevance_score"],
-        chat_result["status"],
-        llm_search_result["answer"],
-        llm_search_result["source"],
-        llm_search_result["status"],
-        decision["result"],
-        study_task_id
-    )
-
-    return data
+    split_result = await rag_chat_agent.split_query(query_text)
+    split_questions = split_result["split_questions"]
+    split_questions.append(query_text)
+
+    # Process each question in parallel with asyncio.gather
+    tasks = [
+        process_question(question, query_text, rag_chat_agent)
+        for question in split_questions
+    ]
+
+    # Wait for all tasks to complete and collect the results
+    data_list = await asyncio.gather(*tasks)
+    return data_list
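
The rewritten rag_search fans the expanded question set out with asyncio.gather. A self-contained sketch of that pattern under the same error-handling contract (handle_one is a hypothetical stand-in for process_question):

    import asyncio

    async def handle_one(question: str) -> dict:
        # Stand-in for process_question: errors are caught per task,
        # so the gather() below never sees a raised exception.
        try:
            return {"query": question, "status": 1}
        except Exception as e:
            return {"query": question, "error": str(e)}

    async def fan_out(query_text: str, split_questions: list) -> list:
        questions = split_questions + [query_text]  # the original query is searched too
        tasks = [handle_one(q) for q in questions]
        return await asyncio.gather(*tasks)  # results come back in input order

    print(asyncio.run(fan_out("original question", ["broader question A", "broader question B"])))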

+ 74 - 41
routes/blueprint.py

@@ -414,9 +414,9 @@ async def chat():
 
     rag_chat_agent = RAGChatAgent()
     chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
-    study_task_id = None
-    if chat_result["status"] == 0:
-        study_task_id = study(query_text)['task_id']
+    # study_task_id = None
+    # if chat_result["status"] == 0:
+    #     study_task_id = study(query_text)['task_id']
     llm_search = await rag_chat_agent.llm_search(query_text)
     decision = await rag_chat_agent.make_decision(chat_result, llm_search)
     data = {
@@ -436,7 +436,6 @@ async def chat():
         llm_search["source"],
         llm_search["status"],
         decision["result"],
-        study_task_id,
         is_web=1,
     )
     return jsonify({"status_code": 200, "detail": "success", "data": data})
@@ -521,44 +520,78 @@ async def delete_task():
 async def rag_search():
     body = await request.get_json()
     query_text = body.get("queryText")
-    dataset_id_strs = "11,12"
-    dataset_ids = dataset_id_strs.split(",")
-    search_type = "hybrid"
-
-    query_results = await query_search(
-        query_text=query_text,
-        filters={"dataset_id": dataset_ids},
-        search_type=search_type,
-        limit=5,
-    )
-    resource = get_resource_manager()
-    chat_result_mapper = ChatResult(resource.mysql_client)
     rag_chat_agent = RAGChatAgent()
-    chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
-    study_task_id = None
-    if chat_result["status"] == 0:
-        study_task_id = study(query_text)['task_id']
-    llm_search = await rag_chat_agent.llm_search(query_text)
-    decision = await rag_chat_agent.make_decision(chat_result, llm_search)
-    data = {
-        "result": decision["result"],
-        "status": decision["status"],
-        "relevance_score": decision["relevance_score"],
-    }
-    await chat_result_mapper.insert_chat_result(
-        query_text,
-        dataset_id_strs,
-        json.dumps(query_results, ensure_ascii=False),
-        chat_result["summary"],
-        chat_result["relevance_score"],
-        chat_result["status"],
-        llm_search["answer"],
-        llm_search["source"],
-        llm_search["status"],
-        decision["result"],
-        study_task_id
-    )
-    return jsonify({"status_code": 200, "detail": "success", "data": data})
+    split_result = await rag_chat_agent.split_query(query_text)
+    split_questions = split_result["split_questions"]
+    split_questions.append(query_text)
+
+    # Process each question in parallel with asyncio.gather
+    tasks = [
+        process_question(question, query_text, rag_chat_agent)
+        for question in split_questions
+    ]
+
+    # Wait for all tasks to complete and collect the results
+    data_list = await asyncio.gather(*tasks)
+    return jsonify({"status_code": 200, "detail": "success", "data": data_list})
+
+
+async def process_question(question, query_text, rag_chat_agent):
+    try:
+        dataset_id_strs = "11,12"
+        dataset_ids = dataset_id_strs.split(",")
+        search_type = "hybrid"
+
+        # Run the retrieval query for this question
+        query_results = await query_search(
+            query_text=question,
+            filters={"dataset_id": dataset_ids},
+            search_type=search_type,
+        )
+
+        resource = get_resource_manager()
+        chat_result_mapper = ChatResult(resource.mysql_client)
+
+        # Asynchronously run the chat against DeepSeek
+        chat_result = await rag_chat_agent.chat_with_deepseek(question, query_results)
+
+        # Decide whether a study task needs to run
+        study_task_id = None
+        if chat_result["status"] == 0:
+            study_task_id = study(question)["task_id"]
+
+        # Asynchronously fetch the LLM search result
+        llm_search_result = await rag_chat_agent.llm_search(question)
+
+        # Run the decision logic
+        decision = await rag_chat_agent.make_decision(chat_result, llm_search_result)
+
+        # Build the data to return
+        data = {
+            "query": question,
+            "result": decision["result"],
+            "status": decision["status"],
+            "relevance_score": decision["relevance_score"],
+        }
+
+        # Insert the result into the database
+        await chat_result_mapper.insert_chat_result(
+            question,
+            dataset_id_strs,
+            json.dumps(query_results, ensure_ascii=False),
+            chat_result["summary"],
+            chat_result["relevance_score"],
+            chat_result["status"],
+            llm_search_result["answer"],
+            llm_search_result["source"],
+            llm_search_result["status"],
+            decision["result"],
+            study_task_id,
+        )
+        return data
+    except Exception as e:
+        print(f"Error processing question: {question}. Error: {str(e)}")
+        return {"query": question, "error": str(e)}
 
 
 @server_bp.route("/chat/history", methods=["GET"])