Procházet zdrojové kódy

非工具类型处理

jihuaqiang před 1 týdnem
rodič
revize
598607728e

+ 84 - 6
src/agent/query_agent.py

@@ -46,21 +46,80 @@ class QueryGenerationAgent:
         """创建LangGraph状态图"""
         workflow = StateGraph(AgentState)
         
-        # 添加节点(仅保留 生成 与 保存)
+        # 添加节点:分类 -> 生成 -> 保存
+        workflow.add_node("classify_question", self._classify_question)
         workflow.add_node("generate_initial_queries", self._generate_initial_queries)
         workflow.add_node("save_queries", self._save_queries)
         
         # 设置入口点
-        workflow.set_entry_point("generate_initial_queries")
+        workflow.set_entry_point("classify_question")
         
-        # 添加边
+        # 添加条件边:内容直接结束,工具进入生成
+        try:
+            # 优先使用条件路由(若LangGraph版本支持)
+            workflow.add_conditional_edges(
+                "classify_question",
+                self._route_after_classify,
+                {
+                    "TOOL": "generate_initial_queries",
+                    "CONTENT": END
+                }
+            )
+        except Exception:
+            # 兼容:不支持条件边时,继续到生成节点,由生成节点自行拦截
+            workflow.add_edge("classify_question", "generate_initial_queries")
         workflow.add_edge("generate_initial_queries", "save_queries")
         workflow.add_edge("save_queries", END)
         
         return workflow.compile()
+
+    def _classify_question(self, state: AgentState) -> AgentState:
+        """判断问题知识类型:工具知识 / 内容知识"""
+        question = state.get("question", "")
+        instruction = (
+            "你是一个分类助手。请根据以下标准判断问题类型并只输出结果:\n"
+            "- 工具知识:涉及软件/工具/编程/API/SDK/命令/安装/配置/使用/部署/调试/版本/参数/代码/集成/CLI 等操作与实现。\n"
+            "- 内容知识:话题洞察、趋势、创作灵感、正文内容、案例分析、概念解释、非工具操作的问题。\n"
+            "要求:严格只输出两个词之一——工具知识 或 内容知识;不要输出任何其它字符、解释或标点。"
+        )
+        prompt = ChatPromptTemplate.from_messages([
+            SystemMessage(content=instruction),
+            HumanMessage(content=question)
+        ])
+        try:
+            response = self.llm.invoke(prompt.format_messages())
+            text = (response.content or "").strip()
+            logger.info(f"问题类型判断结果: {text}")
+            kt = "工具知识" if "工具" in text else "内容知识"
+            state["knowledgeType"] = kt
+            # 若为内容,直接将任务标记为失败并准备结束
+            if kt != "工具知识":
+                try:
+                    if state.get("task_id", 0) > 0:
+                        self.task_dao.mark_task_failed(state["task_id"], "暂不支持非工具内容")
+                except Exception:
+                    pass
+                state["result_queries"] = []
+        except Exception:
+            # 失败默认判为内容知识以避免误触发
+            state["knowledgeType"] = "内容知识"
+            try:
+                if state.get("task_id", 0) > 0:
+                    self.task_dao.mark_task_failed(state["task_id"], "暂不支持非工具内容")
+            except Exception:
+                pass
+            state["result_queries"] = []
+        return state
+
+    def _route_after_classify(self, state: AgentState) -> str:
+        """根据分类结果路由:工具 -> TOOL;内容 -> CONTENT"""
+        return "TOOL" if state.get("knowledgeType") == "工具知识" else "CONTENT"
     
     def _generate_initial_queries(self, state: AgentState) -> AgentState:
         """生成 refined_queries(从结构化JSON中聚合三类关键词)"""
+        # 若为内容类型,直接抛出以在上层处理
+        if state.get("knowledgeType") != "工具知识":
+            raise ValueError("不支持内容类型问题")
         question = state["question"]
         # 使用新的结构化系统提示
         prompt = ChatPromptTemplate.from_messages([
@@ -113,9 +172,9 @@ class QueryGenerationAgent:
             logger.warning("没有查询词需要保存")
             return state
         
-        # 合并 knowledgeType 与每个查询词,形成提交数据
+        # 合并 knowledgeType 与每个查询词,附加 task_id,形成提交数据
         result_items: List[Dict[str, str]] = [
-            {"query": q, "knowledgeType": knowledge_type} for q in refined_queries
+            {"query": q, "knowledgeType": knowledge_type, "task_id": state.get("task_id", 0)} for q in refined_queries
         ]
         state["result_queries"] = result_items
         
@@ -239,7 +298,7 @@ class QueryGenerationAgent:
             "initial_queries": [],
             "refined_queries": [],
             "result_queries": [],
-            "knowledgeType": knowledge_type or "内容知识"
+            "knowledgeType": "工具知识"
         }
         
         try:
@@ -252,3 +311,22 @@ class QueryGenerationAgent:
             # 降级处理:返回原始问题
             return [question]
 
+    def is_tool_question(self, question: str) -> bool:
+        """同步判断问题是否为工具知识类型。"""
+        instruction = (
+            "你是一个分类助手。请根据以下标准判断问题类型并只输出结果:\n"
+            "- 工具知识:涉及软件/工具/编程/API/SDK/命令/安装/配置/使用/部署/调试/版本/参数/代码/集成/CLI 等操作与实现。\n"
+            "- 内容知识:话题洞察、趋势、创作灵感、正文内容、案例分析、概念解释、非工具操作的问题。\n"
+            "要求:严格只输出两个词之一——工具知识 或 内容知识;不要输出任何其它字符、解释或标点。"
+        )
+        prompt = ChatPromptTemplate.from_messages([
+            SystemMessage(content=instruction),
+            HumanMessage(content=question)
+        ])
+        try:
+            response = self.llm.invoke(prompt.format_messages())
+            text = (response.content or "").strip()
+            return "工具" in text
+        except Exception:
+            return False
+

+ 1 - 1
src/api/main.py

@@ -144,7 +144,7 @@ async def generate_queries(request: QuestionRequest):
         task_id = int(time.time() * 1000)
         
         # 创建任务记录,状态设置为0(待执行)
-        task_dao.create_task(task_id, request.question, knowledge_type=request.knowledgeType or "内容知识")
+        task_dao.create_task(task_id, request.question, knowledge_type="工具知识")
         logger.info(f"创建任务: {task_id},状态: 待执行")
         
         # 立即返回待执行状态

+ 30 - 7
src/database/models.py

@@ -18,7 +18,7 @@ class QueryTaskStatus:
 class KnowledgeSuggestQuery:
     """知识查询建议模型"""
     
-    def __init__(self, task_id: int, question: str, querys: Optional[List[str]] = None, status: int = QueryTaskStatus.PENDING, knowledgeType: str = ""):
+    def __init__(self, task_id: int, question: str, querys: Optional[List[str]] = None, status: int = QueryTaskStatus.PENDING, knowledgeType: str = "", err_msg: str = ""):
         """
         初始化查询任务
         
@@ -33,6 +33,7 @@ class KnowledgeSuggestQuery:
         self.querys = querys or []
         self.status = status
         self.knowledgeType = knowledgeType
+        self.err_msg = err_msg or ""
     
     def to_dict(self) -> Dict[str, Any]:
         """转换为字典"""
@@ -41,7 +42,8 @@ class KnowledgeSuggestQuery:
             'question': self.question,
             'querys': json.dumps(self.querys, ensure_ascii=False) if self.querys else None,
             'status': self.status,
-            'knowledgeType': self.knowledgeType
+            'knowledgeType': self.knowledgeType,
+            'err_msg': self.err_msg or None
         }
     
     @classmethod
@@ -59,7 +61,8 @@ class KnowledgeSuggestQuery:
             question=data['question'],
             querys=querys,
             status=data['status'],
-            knowledgeType=data['knowledgeType']
+            knowledgeType=data.get('knowledgeType', ""),
+            err_msg=data.get('err_msg', "")
         )
 
 
@@ -83,15 +86,16 @@ class QueryTaskDAO:
         try:
             with self.db_manager.get_cursor() as cursor:
                 sql = """
-                INSERT INTO knowledge_suggest_query (task_id, question, status, knowledgeType)
-                VALUES (%s, %s, %s, %s)
+                INSERT INTO knowledge_suggest_query (task_id, question, status, knowledgeType, err_msg)
+                VALUES (%s, %s, %s, %s, %s)
                 ON DUPLICATE KEY UPDATE
                 question = VALUES(question),
                 status = VALUES(status),
                 querys = NULL,
-                knowledgeType = VALUES(knowledgeType)
+                knowledgeType = VALUES(knowledgeType),
+                err_msg = NULL
                 """
-                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING, knowledge_type or "内容知识"))
+                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING, knowledge_type or "内容知识", None))
                 return True
         except Exception as e:
             logger.error(f"创建任务失败: {e}")
@@ -116,6 +120,25 @@ class QueryTaskDAO:
         except Exception as e:
             logger.error(f"更新任务状态失败: {e}")
             return False
+
+    def mark_task_failed(self, task_id: int, err_msg: str) -> bool:
+        """
+        将任务标记为失败并记录错误信息
+        """
+        try:
+            with self.db_manager.get_cursor() as cursor:
+                try:
+                    sql = "UPDATE knowledge_suggest_query SET status = %s, err_msg = %s, knowledgeType = %s WHERE task_id = %s"
+                    cursor.execute(sql, (QueryTaskStatus.FAILED, err_msg, "内容知识", task_id))
+                    return cursor.rowcount > 0
+                except Exception:
+                    # 回退到仅更新状态
+                    sql = "UPDATE knowledge_suggest_query SET status = %s WHERE task_id = %s"
+                    cursor.execute(sql, (QueryTaskStatus.FAILED, task_id))
+                    return cursor.rowcount > 0
+        except Exception as e:
+            logger.error(f"标记任务失败时出错: {e}")
+            return False
     
     def update_task_results(self, task_id: int, querys: List[str], status: int = QueryTaskStatus.SUCCESS) -> bool:
         """

+ 1 - 2
src/models/schemas.py

@@ -1,11 +1,10 @@
-from typing import List, Optional
+from typing import List
 from pydantic import BaseModel, Field
 
 
 class QuestionRequest(BaseModel):
     """问题请求模型"""
     question: str = Field(..., description="用户提出的问题", min_length=1, max_length=1000)
-    knowledgeType: Optional[str] = Field(default="内容知识", description="知识类型:内容知识/工具知识")
 
 
 class QueryResponse(BaseModel):

+ 9 - 2
src/tools/scheduler.py

@@ -80,8 +80,15 @@ class TaskScheduler:
             self.task_dao.update_task_status(task.task_id, QueryTaskStatus.RUNNING)
             
             try:
-                # 使用Agent生成查询词,传入knowledgeType
-                queries = await self.agent.generate_queries(task.question, task.task_id, getattr(task, 'knowledgeType', '') or '')
+                # 使用Agent生成查询词(Agent内部会在内容类型时直接失败并返回空)
+                queries = await self.agent.generate_queries(task.question, task.task_id)
+                
+                # 若为空,视为不支持内容类型,标记失败
+                if not queries:
+                    # 非工具问题:标记失败并写入错误信息
+                    self.task_dao.mark_task_failed(task.task_id, "暂不支持非工具内容")
+                    logger.info(f"任务 {task.task_id} 非工具问题,标记为失败并写入错误信息")
+                    return
                 
                 # 更新任务结果
                 success = self.task_dao.update_task_results(task.task_id, queries, QueryTaskStatus.SUCCESS)