Kaynağa Gözat

增加query_type字段

jihuaqiang 4 ay önce
ebeveyn
işleme
a9f63fdbac
3 değiştirilmiş dosya ile 9 ekleme ve 9 silme
  1. 4 4
      src/agent/query_agent.py
  2. 3 3
      src/database/models.py
  3. 2 2
      src/tools/scheduler.py

+ 4 - 4
src/agent/query_agent.py

@@ -469,7 +469,7 @@ class QueryGenerationAgent:
             raise ValueError("提取内容不是JSON数组")
         return data
 
-    async def generate_queries(self, question: str, need_store: int = 1, task_id: int = 0, knowledge_type: str = "") -> List[str]:
+    async def generate_queries(self, question: str, need_store: int = 1, task_id: int = 0, knowledge_type: str = "") -> tuple[List[str], str]:
         """
         生成查询词的主入口
         
@@ -478,7 +478,7 @@ class QueryGenerationAgent:
             task_id: 任务ID
             knowledge_type: 知识类型(可选,用于兼容)
         Returns:
-            生成的查询词列表
+            元组:(生成的查询词列表, 问题类型)
         """
         initial_state = {
             "question": question,
@@ -493,14 +493,14 @@ class QueryGenerationAgent:
         
         try:
             result = await self.graph.ainvoke(initial_state)
-            return result["result_queries"]
+            return result["result_queries"], result["query_type"]
         except Exception as e:
             logger.error(f"生成查询词失败: {e}")
             # 更新任务状态为失败
             if task_id > 0:
                 self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
             # 降级处理:返回原始问题
-            return [question]
+            return [question], "How"  # 默认返回How类型
 
     def is_tool_question(self, question: str) -> bool:
         """同步判断问题是否为工具知识类型。"""

+ 3 - 3
src/database/models.py

@@ -147,7 +147,7 @@ class QueryTaskDAO:
             logger.error(f"标记任务失败时出错: {e}")
             return False
     
-    def update_task_results(self, task_id: int, querys: List[str], status: int = QueryTaskStatus.SUCCESS) -> bool:
+    def update_task_results(self, task_id: int, querys: List[str], query_type: str, status: int = QueryTaskStatus.SUCCESS) -> bool:
         """
         更新任务结果
         
@@ -161,9 +161,9 @@ class QueryTaskDAO:
         """
         try:
             with self.db_manager.get_cursor() as cursor:
-                sql = "UPDATE knowledge_suggest_query SET querys = %s, status = %s WHERE task_id = %s"
+                sql = "UPDATE knowledge_suggest_query SET querys = %s, status = %s, query_type = %s WHERE task_id = %s"
                 querys_json = json.dumps(querys, ensure_ascii=False)
-                cursor.execute(sql, (querys_json, status, task_id))
+                cursor.execute(sql, (querys_json, status, query_type, task_id))
                 return cursor.rowcount > 0
         except Exception as e:
             logger.error(f"更新任务结果失败: {e}")

+ 2 - 2
src/tools/scheduler.py

@@ -81,7 +81,7 @@ class TaskScheduler:
             
             try:
                 # 使用Agent生成查询词(Agent内部会在内容类型时直接失败并返回空)
-                queries = await self.agent.generate_queries(task.question, task.need_store, task.task_id)
+                queries, query_type = await self.agent.generate_queries(task.question, task.need_store, task.task_id)
                 
                 # 若为空,视为不支持内容类型,标记失败
                 if not queries:
@@ -90,7 +90,7 @@ class TaskScheduler:
                     return
                 
                 # 更新任务结果
-                success = self.task_dao.update_task_results(task.task_id, queries, QueryTaskStatus.SUCCESS)
+                success = self.task_dao.update_task_results(task.task_id, queries, query_type, QueryTaskStatus.SUCCESS)
                 
                 if success:
                     logger.info(f"任务 {task.task_id} 处理成功,生成 {len(queries)} 个查询词")