jihuaqiang 1 week ago
parent
commit
96bd8793bf
3 changed files with 140 additions and 70 deletions
  1. 32 8
      agent.py
  2. 1 1
      agents/clean_agent/agent.py
  3. 107 61
      agents/expand_agent/agent.py

+ 32 - 8
agent.py

@@ -21,7 +21,7 @@ from fastapi.responses import JSONResponse
 from pydantic import BaseModel, Field
 import uvicorn
 from agents.clean_agent.agent import execute_agent_with_api
-from agents.expand_agent.agent import execute_expand_agent_with_api
+from agents.expand_agent.agent import execute_expand_agent_with_api, _update_expansion_status
 
 # LangGraph 相关导入
 try:
@@ -61,6 +61,9 @@ class TriggerResponse(BaseModel):
     success: int
     details: List[Dict[str, Any]]
 
+class ExpandRequest(BaseModel):
+    requestId: str = Field(..., description="扩展查询请求ID")
+
 # 全局变量
 identify_tool = None
 
@@ -86,6 +89,16 @@ def update_request_status(request_id: str, status: int):
     except Exception as e:
         logger.error(f"更新请求状态异常: requestId={request_id}, status={status}, error={e}")
 
+def _update_expansion_status(requestId: str, status: int):
+    """更新扩展查询状态"""
+    try:
+        from utils.mysql_db import MysqlHelper
+        sql = "UPDATE knowledge_request SET expansion_status = %s WHERE request_id = %s"
+        MysqlHelper.update_values(sql, (status, requestId))
+        logger.info(f"更新扩展查询状态成功: requestId={requestId}, status={status}")
+    except Exception as e:
+        logger.error(f"更新扩展查询状态失败: requestId={requestId}, status={status}, error={e}")
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     """应用生命周期管理"""
@@ -589,21 +602,32 @@ async def extract(input: str):
         raise HTTPException(status_code=500, detail=f"执行Agent时出错: {str(e)}")
 
 @app.post("/expand")
-async def expand(requestId: str):
+async def expand(request: ExpandRequest, background_tasks: BackgroundTasks):
     """
-    执行Agent处理用户指令
+    执行扩展查询处理
     
     Args:
-        requestId: 请求ID
+        request: 包含请求ID的请求体
+        background_tasks: FastAPI 后台任务
         
     Returns:
-        dict: 包含执行结果的字典
+        dict: 包含执行状态的字典
     """
     try:
-        result = execute_expand_agent_with_api(requestId)
-        return {"status": 1, "result": result}
+        requestId = request.requestId
+        
+        # 立即更新状态为处理中
+        _update_expansion_status(requestId, 1)
+        
+        # 添加后台任务
+        background_tasks.add_task(execute_expand_agent_with_api, requestId)
+        
+        # 立即返回状态
+        return {"status": 1, "requestId": requestId, "message": "扩展查询处理已启动"}
+        
     except Exception as e:
-        raise HTTPException(status_code=500, detail=f"执行Agent时出错: {str(e)}")
+        logger.error(f"启动扩展查询处理失败: requestId={requestId}, error={e}")
+        raise HTTPException(status_code=500, detail=f"启动扩展查询处理时出错: {str(e)}")
 
 if __name__ == "__main__":
     # 启动服务

+ 1 - 1
agents/clean_agent/agent.py

@@ -5,7 +5,7 @@ from langgraph.graph.message import add_messages
 import os
 from langchain.chat_models import init_chat_model
 from IPython.display import Image, display
-from tools import evaluation_extraction_tool
+from .tools import evaluation_extraction_tool
 
 from langgraph.prebuilt import ToolNode, tools_condition
 from langgraph.checkpoint.memory import InMemorySaver

+ 107 - 61
agents/expand_agent/agent.py

@@ -34,55 +34,41 @@ def _fetch_parsing_data_by_request(request_id: str) -> List[str]:
 
 
 def _build_prompt(data_samples: str, input_query: str) -> str:
-    """构建用于扩展查询的 Prompt"""    
-    return f"""# 角色与目标
-你是一个Query词扩展与优化专家,任务是从输入文本中提取适合用于搜索"创作/制作方法论"的关键词,并基于这些关键词生成多组可直接用于爬虫搜索的query词模板。
-目标是:
-1. 提取文本中与"创作方法论/制作方法论"相关的核心关键词。
-2. 为每个关键词生成同义词、近义词、衍生表达。
-3. 将关键词与固定的模板组合,输出可单独搜索的query词,以及多关键词组合的query词。
-4. 输出格式要求结构化,支持工程化直接使用。
-
-
-# 任务
-1. 提取文本中的核心关键词(与"创作/制作方法论"高度相关)。
-2. 为每个关键词生成同义/扩展词。
-3. 按以下模板格式输出结果。
-4. 将任务中1和2的关键词(含扩展的关键词)需要填入到输出格式的 "query_templates"与"query_combinations":模板后输出结果。
-
-# 输出格式
-{{
-  "core_keywords": [
-    "关键词1",
-    "关键词2",
+    """构建用于扩展查询的 Prompt,使用 expansion.md 模板"""
+    
+    # 读取 expansion.md 模板文件
+    try:
+        template_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "prompt", "expansion.md")
+        with open(template_path, 'r', encoding='utf-8') as f:
+            template = f.read()
+        
+        # 替换模板中的占位符
+        prompt = template.replace("{Original_Query}", input_query or "无")
+        prompt = prompt.replace("{Content_For_Expansion}", data_samples)
+        
+        return prompt
+        
+    except Exception as e:
+        logger.error(f"读取 expansion.md 模板失败: {e}")
+        # 如果模板读取失败,使用备用 prompt
+        return f"""你是一位顶级的知识库专家,精通语义分析、信息检索和搜索优化策略。
+
+根据以下原始查询和内容,生成扩展的查询词列表:
+
+**原始查询:** {input_query or "无"}
+
+**扩展要基于的内容:** {data_samples}
+
+请生成一个JSON数组形式的扩展查询词列表,确保每个查询词都具体、可操作、信息量丰富。
+
+输出格式:
+```json
+[
+    "查询词1",
+    "查询词2",
     ...
-  ],
-  "expanded_keywords": {{
-    "关键词1": ["同义词1", "同义词2", "变体1"],
-    "关键词2": ["同义词1", "同义词2", "变体1"]
-  }},
-  "query_templates": [
-    "如何 + {{关键词}}",
-    "{{关键词}} + 方法论",
-    "{{关键词}} + 技巧",
-    "{{关键词}} + 步骤",
-    "{{关键词}} + 流程",
-    "{{关键词}} + 案例",
-    "{{关键词}} + 总结",
-    "{{关键词}} + 原理",
-    "如何提升 + {{关键词}}",
-    "从零开始 + {{关键词}}"
-  ],
-  "query_combinations": [
-    "{{关键词1}} + {{关键词2}} + 方法论",
-    "{{关键词1}} + {{关键词2}} + 案例",
-    "如何结合 {{关键词1}} 与 {{关键词2}}"
-  ]
-}}
-
-输入文本: {data_samples}
-
-请按照上述格式输出JSON结果。"""
+]
+```"""
 
 
 def _run_llm(prompt: str) -> List[str]:
@@ -94,16 +80,43 @@ def _run_llm(prompt: str) -> List[str]:
         processor = GeminiProcessor()
         result = processor.process(content=prompt, system_prompt="你是专业的查询扩展助手")
         print(f"result: {result}")
-        # 尝试解析返回结果
+        
+        # 处理返回结果
         if isinstance(result, dict):
+            # 如果有错误,直接返回空列表
+            if "error" in result:
+                logger.error(f"Gemini API 返回错误: {result['error']}")
+                return []
+            # 如果结果在 result 字段中
             text = result.get("result", "") or result.get("raw_response", "")
         else:
             text = str(result)
-            
+        
+        # 清理文本,移除 markdown 代码块标记
+        if "```json" in text:
+            # 提取 ```json 和 ``` 之间的内容
+            start = text.find("```json") + 7
+            end = text.find("```", start)
+            if end != -1:
+                text = text[start:end].strip()
+        elif "```" in text:
+            # 提取 ``` 之间的内容
+            start = text.find("```") + 3
+            end = text.find("```", start)
+            if end != -1:
+                text = text[start:end].strip()
+        
+        # 尝试解析 JSON
         try:
             queries = json.loads(text)
-            return queries if isinstance(queries, list) else []
-        except:
+            if isinstance(queries, list):
+                # 确保所有元素都是字符串
+                return [str(q) for q in queries if q]
+            else:
+                logger.warning(f"Gemini 返回的不是列表格式: {type(queries)}")
+                return []
+        except json.JSONDecodeError as e:
+            logger.error(f"JSON 解析失败: {e}, 原始文本: {text}")
             return []
             
     except Exception as e:
@@ -125,28 +138,61 @@ def _heuristic_expand(input_query: str) -> List[str]:
         f"{base} 高级技巧"
     ]
 
+def _update_expansion_status(requestId: str, status: int):
+    sql = "UPDATE knowledge_request SET expansion_status = %s WHERE request_id = %s"
+    MysqlHelper.update_values(sql, (status, requestId))
 
 def execute_expand_agent_with_api(requestId: str, query: str = "") -> Dict[str, Any]:
     """对外暴露的API:根据requestId查询数据,生成扩展查询"""
     # 获取数据
     data_samples = _fetch_parsing_data_by_request(requestId)
-    print(f"data_samples: {data_samples[0]}")
     
     # 构建prompt
     prompt = _build_prompt(data_samples[0], query)
-    print(f"prompt: {prompt}")
 
     # 生成扩展查询
     expanded = _run_llm(prompt)
     if not expanded:
         expanded = _heuristic_expand(query)
-    print(f"expanded: {expanded}")
-    return {
-        "requestId": requestId,
-        "inputQuery": query,
-        "prompt": prompt,
-        "expandedQueries": expanded
-    } 
+    
+    # 将扩展查询结果插入到 knowledge_expand_content 表
+    try:
+        # 先检查是否已存在记录
+        check_sql = "SELECT id FROM knowledge_expand_content WHERE request_id = %s LIMIT 1"
+        existing_record = MysqlHelper.get_values(check_sql, (requestId,))
+        
+        # 将 expanded 列表转换为 JSON 字符串
+        expand_querys_json = json.dumps(expanded, ensure_ascii=False)
+        print(f"expand_querys_json: {expand_querys_json}")
+        
+        if existing_record:
+            # 记录已存在,执行更新
+            update_sql = """
+            UPDATE knowledge_expand_content 
+            SET expand_querys = %s, query = %s, create_time = NOW()
+            WHERE request_id = %s
+            """
+            affected_rows = MysqlHelper.update_values(update_sql, (expand_querys_json, query, requestId))
+            logger.info(f"扩展查询结果已更新: requestId={requestId}, affected_rows={affected_rows}")
+        else:
+            # 记录不存在,执行插入
+            insert_sql = """
+            INSERT INTO knowledge_expand_content 
+            (request_id, create_time, expand_querys, query) 
+            VALUES (%s, NOW(), %s, %s)
+            """
+            insert_result = MysqlHelper.insert_and_get_id(insert_sql, (requestId, expand_querys_json, query))
+            logger.info(f"扩展查询结果已插入: requestId={requestId}, insert_id={insert_result}")
+        
+        # 更新状态为处理完成
+        _update_expansion_status(requestId, 2)
+            
+    except Exception as e:
+        logger.error(f"保存扩展查询结果到数据库时出错: requestId={requestId}, error={e}")
+        _update_expansion_status(requestId, 3)
+    
+    return {"status": 1, "requestId": requestId}
+
 
 if __name__ == "__main__":
     queries = execute_expand_agent_with_api("REQUEST_001")