jihuaqiang 1 semana atrás
pai
commit
3ddbefc298
5 arquivos alterados com 126 adições e 33 exclusões
  1. 119 24
      src/agent/query_agent.py
  2. 1 1
      src/api/main.py
  3. 5 6
      src/database/models.py
  4. 0 1
      src/models/schemas.py
  5. 1 1
      src/tools/scheduler.py

+ 119 - 24
src/agent/query_agent.py

@@ -17,9 +17,9 @@ class AgentState(TypedDict):
     task_id: int
     initial_queries: List[str]
     refined_queries: List[str]
+    result_queries: List[Dict[str, str]]
     context: str
     iteration_count: int
-    knowledgeType: str
 
 
 class QueryGenerationAgent:
@@ -54,6 +54,7 @@ class QueryGenerationAgent:
         workflow.add_node("generate_initial_queries", self._generate_initial_queries)
         workflow.add_node("refine_queries", self._refine_queries)
         workflow.add_node("validate_queries", self._validate_queries)
+        workflow.add_node("classify_queries", self._classify_queries)
         workflow.add_node("save_queries", self._save_queries)
         
         # 设置入口点
@@ -63,7 +64,8 @@ class QueryGenerationAgent:
         workflow.add_edge("analyze_question", "generate_initial_queries")
         workflow.add_edge("generate_initial_queries", "refine_queries")
         workflow.add_edge("refine_queries", "validate_queries")
-        workflow.add_edge("validate_queries", "save_queries")
+        workflow.add_edge("validate_queries", "classify_queries")
+        workflow.add_edge("classify_queries", "save_queries")
         workflow.add_edge("save_queries", END)
         
         return workflow.compile()
@@ -168,6 +170,14 @@ class QueryGenerationAgent:
         logger.info(f"查询词验证结果: {validated_queries}")
         state["refined_queries"] = validated_queries
         return state
+
+    def _classify_queries(self, state: AgentState) -> AgentState:
+        """Classify each refined query via _classify_with_llm and store the list in state["result_queries"]."""
+        refined_queries = state.get("refined_queries", [])
+        # Delegate the classification to the LLM-backed helper
+        result_items: List[Dict[str, str]] = self._classify_with_llm(refined_queries)
+        state["result_queries"] = result_items
+        return state
     
     def _save_queries(self, state: AgentState) -> AgentState:
         """保存查询词到外部接口节点"""
@@ -178,26 +188,25 @@ class QueryGenerationAgent:
             logger.warning("没有查询词需要保存")
             return state
         
-        # 调用外部接口保存查询词
+        # 调用外部接口保存查询词(按类型分组)
         try:
             url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
             headers = {"Content-Type": "application/json"}
-            
-            # 根据问题内容判断知识类型,这里可以根据实际需求调整逻辑
-            knowledge_type = state["knowledgeType"]  # 默认类型,可以根据问题内容动态判断
-            
-            data = {
-                "knowledgeType": knowledge_type,
-                "queryWords": refined_queries
-            }
-            
-            # 使用httpx发送请求
-            with httpx.Client() as client:
-                response = client.post(url, headers=headers, json=data, timeout=30)
-                response.raise_for_status()
-                
-            logger.info(f"查询词保存成功: {refined_queries}")
-            
+
+            # 仅使用前一步的分类结果,不做即时分类
+            result_items: List[Dict[str, str]] = state.get("result_queries", [])
+            if not result_items:
+                logger.warning("缺少分类结果result_queries,跳过外部提交")
+                return state
+
+            if result_items:
+                with httpx.Client() as client:
+                    data_content = {"queryWords": result_items}
+                    resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
+                    resp1.raise_for_status()
+
+            logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
+
         except httpx.HTTPError as e:
             logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
         except Exception as e:
@@ -205,14 +214,100 @@ class QueryGenerationAgent:
         
         return state
     
-    async def generate_queries(self, question: str, task_id: int = 0, knowledgeType: str = "") -> List[str]:
+    def _infer_knowledge_type(self, query: str) -> str:
+        """Keyword heuristic: return "工具知识" if the query contains a tool keyword, otherwise "内容知识"."""
+        tool_keywords = [  # NOTE(review): no caller of this helper is visible in this diff — confirm it is used
+            "安装", "配置", "使用", "教程", "API", "SDK", "命令", "指令", "版本",
+            "错误", "异常", "调试", "部署", "集成", "调用", "参数", "示例", "代码",
+            "CLI", "tool", "library", "framework"
+        ]
+        lower_q = query.lower()  # both sides are lowercased, so matching is case-insensitive
+        for kw in tool_keywords:
+            if kw.lower() in lower_q:
+                return "工具知识"
+        return "内容知识"
+
+    def _classify_with_llm(self, queries: List[str]) -> List[Dict[str, str]]:
+        """Ask the LLM to label each query as "内容知识" (content knowledge) or "工具知识" (tool knowledge).
+
+        Returns [{"query": q, "knowledgeType": "内容知识"|"工具知识"}, ...]; returns [] for empty input.
+        If the LLM call or parsing fails, every query falls back to "内容知识" (no keyword heuristic is used).
+        """
+        if not queries:
+            return []
+
+        instruction = (
+            "你是一名分类助手。请将下面的查询词逐一分类为‘内容知识’或‘工具知识’。\n"
+            "请只返回严格的JSON数组,每个元素为对象:{\"query\": 原始查询词, \"knowledgeType\": \"内容知识\" 或 \"工具知识\"}。\n"
+            "不要输出任何解释或多余文本。"
+        )
+        payload = "\n".join(queries)
+
+        prompt = ChatPromptTemplate.from_messages([
+            SystemMessage(content=instruction),
+            HumanMessage(content=f"查询词列表(每行一个):\n{payload}")
+        ])
+
+        try:
+            response = self.llm.invoke(prompt.format_messages())
+            text = (response.content or "").strip()
+            logger.info(f"LLM分类结果: {text}")
+            # Parse as a JSON array first; on failure, extract it from a code fence or surrounding text
+            try:
+                data = json.loads(text)
+            except Exception:
+                data = self._extract_json_array_from_text(text)
+            result: List[Dict[str, str]] = []
+            for item in data:  # NOTE(review): a non-dict item raises on .get and drops to the outer fallback
+                q = str(item.get("query", "")).strip()
+                kt = str(item.get("knowledgeType", "")).strip()
+                if q and kt in ("内容知识", "工具知识"):
+                    result.append({"query": q, "knowledgeType": kt})
+            # Only when the counts differ: rebuild the list in input order so every query is present
+            if len(result) != len(queries):
+                # Align to the input; queries without a valid LLM label default to "内容知识"
+                mapped = {it["query"]: it["knowledgeType"] for it in result}
+                aligned: List[Dict[str, str]] = []
+                for q in queries:
+                    kt = mapped.get(q, "内容知识")
+                    aligned.append({"query": q, "knowledgeType": kt})
+                return aligned
+            return result
+        except Exception as e:
+            # Fallback: label everything "内容知识" (no keyword matching)
+            logger.warning(f"LLM分类失败,使用降级策略: {e}")
+            return [{"query": q, "knowledgeType": "内容知识"} for q in queries]
+
+    def _extract_json_array_from_text(self, text: str) -> List[Dict[str, Any]]:
+        """Best-effort extraction of a JSON array from model output (possibly in a ```json fence or with extra text); raises ValueError if none is found."""
+        s = (text or "").strip()
+        # Strip a surrounding ``` code fence, if present
+        if s.startswith("```"):
+            # Drop the opening fence line (``` or ```json)
+            first_newline = s.find('\n')
+            if first_newline != -1:
+                s = s[first_newline + 1:]
+            if s.endswith("```"):
+                s = s[:-3]
+            s = s.strip()
+        # Greedy match: spans from the first '[' to the last ']' in the text
+        import re
+        match = re.search(r"\[[\s\S]*\]", s)
+        if not match:
+            raise ValueError("未找到JSON数组片段")
+        json_str = match.group(0)
+        data = json.loads(json_str)
+        if not isinstance(data, list):
+            raise ValueError("提取内容不是JSON数组")
+        return data
+
+    async def generate_queries(self, question: str, task_id: int = 0) -> List[str]:
         """
         生成查询词的主入口
         
         Args:
             question: 用户问题
             task_id: 任务ID
-            knowledgeType: 知识类型
         Returns:
             生成的查询词列表
         """
@@ -221,14 +316,14 @@ class QueryGenerationAgent:
             "task_id": task_id,
             "initial_queries": [],
             "refined_queries": [],
+            "result_queries": [],
             "context": "",
-            "iteration_count": 0,
-            "knowledgeType": knowledgeType
+            "iteration_count": 0
         }
         
         try:
             result = await self.graph.ainvoke(initial_state)
-            return result["refined_queries"]
+            return result["result_queries"]
         except Exception as e:
             # 更新任务状态为失败
             if task_id > 0:

+ 1 - 1
src/api/main.py

@@ -144,7 +144,7 @@ async def generate_queries(request: QuestionRequest):
         task_id = int(time.time() * 1000)
         
         # 创建任务记录,状态设置为0(待执行)
-        task_dao.create_task(task_id, request.question, request.knowledgeType)
+        task_dao.create_task(task_id, request.question)
         logger.info(f"创建任务: {task_id},状态: 待执行")
         
         # 立即返回待执行状态

+ 5 - 6
src/database/models.py

@@ -69,7 +69,7 @@ class QueryTaskDAO:
     def __init__(self):
         self.db_manager = get_db_manager()
     
-    def create_task(self, task_id: int, question: str, knowledgeType: str) -> bool:
+    def create_task(self, task_id: int, question: str) -> bool:
         """
         创建新的查询任务
         
@@ -83,15 +83,14 @@ class QueryTaskDAO:
         try:
             with self.db_manager.get_cursor() as cursor:
                 sql = """
-                INSERT INTO knowledge_suggest_query (task_id, question, status, knowledgeType)
-                VALUES (%s, %s, %s, %s)
+                INSERT INTO knowledge_suggest_query (task_id, question, status)
+                VALUES (%s, %s, %s)
                 ON DUPLICATE KEY UPDATE
                 question = VALUES(question),
                 status = VALUES(status),
-                querys = NULL,
-                knowledgeType = VALUES(knowledgeType)
+                querys = NULL
                 """
-                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING, knowledgeType))
+                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING))
                 return True
         except Exception as e:
             logger.error(f"创建任务失败: {e}")

+ 0 - 1
src/models/schemas.py

@@ -5,7 +5,6 @@ from pydantic import BaseModel, Field
 class QuestionRequest(BaseModel):
     """问题请求模型"""
     question: str = Field(..., description="用户提出的问题", min_length=1, max_length=1000)
-    knowledgeType: str = Field(..., description="知识类型")
 
 
 class QueryResponse(BaseModel):

+ 1 - 1
src/tools/scheduler.py

@@ -81,7 +81,7 @@ class TaskScheduler:
             
             try:
                 # 使用Agent生成查询词
-                queries = await self.agent.generate_queries(task.question, task.task_id, task.knowledgeType)
+                queries = await self.agent.generate_queries(task.question, task.task_id)
                 
                 # 更新任务结果
                 success = self.task_dao.update_task_results(task.task_id, queries, QueryTaskStatus.SUCCESS)