Kaynağa Gözat

支持工具使用 知识query词生成

jihuaqiang 5 ay önce
ebeveyn
işleme
be6e817334

+ 47 - 13
src/agent/query_agent.py

@@ -8,6 +8,7 @@ import json
 
 from ..tools.prompts import (
     STRUCTURED_TOOL_DEMAND_PROMPT,
+    TOOL_USAGE_PROMPT,
     CLASSIFICATION_PROMPT,
     QUERY_CLASSIFICATION_PROMPT,
     WHAT_CLASSIFICATION_PROMPT,
@@ -21,7 +22,6 @@ class AgentState(TypedDict):
     question: str
     task_id: int
     need_store: int
-    initial_queries: List[str]
     refined_queries: List[str]
     result_queries: List[Dict[str, str]]
     knowledgeType: str
@@ -69,6 +69,7 @@ class QueryGenerationAgent:
         # 添加节点
         workflow.add_node("classify_question", self._classify_question)
         workflow.add_node("generate_tool_queries", self._generate_tool_queries)  # 工具类型查询生成
+        workflow.add_node("generate_tool_usage_queries", self._generate_tool_usage_queries)  # 工具使用类型查询生成
         workflow.add_node("classify_content_dimension", self._classify_content_dimension)  # 内容维度分类
         workflow.add_node("expand_content_queries", self._expand_content_queries)  # 内容查询扩展
         workflow.add_node("save_queries", self._save_queries)
@@ -76,13 +77,14 @@ class QueryGenerationAgent:
         # 设置入口点
         workflow.set_entry_point("classify_question")
         
-        # 条件路由:工具知识 vs 内容知识
+        # 条件路由:工具知识 vs 工具使用 vs 内容知识
         try:
             workflow.add_conditional_edges(
                 "classify_question",
                 self._route_after_classify,
                 {
                     "TOOL": "generate_tool_queries",
+                    "TOOL_USAGE": "generate_tool_usage_queries",
                     "CONTENT": "classify_content_dimension"
                 }
             )
@@ -92,6 +94,9 @@ class QueryGenerationAgent:
         # 工具类型:生成 -> 保存 -> 结束
         workflow.add_edge("generate_tool_queries", "save_queries")
         
+        # 工具使用类型:生成 -> 保存 -> 结束
+        workflow.add_edge("generate_tool_usage_queries", "save_queries")
+        
         # 内容类型:分类维度 -> 条件路由
         try:
             workflow.add_conditional_edges(
@@ -112,8 +117,15 @@ class QueryGenerationAgent:
         return workflow.compile()
 
     def _classify_question(self, state: AgentState) -> AgentState:
-        """判断问题知识类型:工具知识 / 内容知识"""
+        """判断问题知识类型:工具知识 / 工具使用 / 内容知识"""
         question = state.get("question", "")
+        
+        print(f"knowledgeType: {state.get('knowledgeType')}")
+        # 如果已经设置了 knowledgeType 且为"工具使用",直接使用
+        if state.get("knowledgeType") == "工具使用":
+            logger.info(f"问题类型已设置为: 工具使用")
+            return state
+        
         instruction = (
             "你是一个分类助手。请根据以下标准判断问题类型并只输出结果:\n"
             "- 工具知识:涉及软件/工具/编程/API/SDK/命令/安装/配置/使用/部署/调试/版本/参数/代码/集成/CLI 等操作与实现。\n"
@@ -137,8 +149,14 @@ class QueryGenerationAgent:
         return state
 
     def _route_after_classify(self, state: AgentState) -> str:
-        """根据分类结果路由:工具 -> TOOL;内容 -> CONTENT"""
-        return "TOOL" if state.get("knowledgeType") == "工具知识" else "CONTENT"
+        """根据分类结果路由:工具知识 -> TOOL;工具使用 -> TOOL_USAGE;内容知识 -> CONTENT"""
+        knowledge_type = state.get("knowledgeType", "")
+        if knowledge_type == "工具使用":
+            return "TOOL_USAGE"
+        elif knowledge_type == "工具知识":
+            return "TOOL"
+        else:
+            return "CONTENT"
     
     def _generate_tool_queries(self, state: AgentState) -> AgentState:
         """生成工具类型的查询词(从结构化JSON中聚合三类关键词)"""
@@ -174,11 +192,32 @@ class QueryGenerationAgent:
                 if q not in seen:
                     seen.add(q)
                     deduped.append(q)
-            state["initial_queries"] = deduped
             state["refined_queries"] = deduped
         except Exception as e:
             logger.warning(f"结构化需求解析失败,降级为原始问题: {e}")
-            state["initial_queries"] = [question]
+            state["refined_queries"] = [question]
+        return state
+    
+    def _generate_tool_usage_queries(self, state: AgentState) -> AgentState:
+        """生成工具使用类型的查询词"""
+        question = state["question"]
+        print(f"工具使用类型查询词: {question}")
+        prompt = ChatPromptTemplate.from_messages([
+            SystemMessage(content=TOOL_USAGE_PROMPT),
+            HumanMessage(content=question)
+        ])
+        try:
+            response = self.llm.invoke(prompt.format_messages())
+            text = (response.content or "").strip()
+            logger.info(f"工具使用类型查询词结果: {text}")
+            # 解析JSON结果
+            try:
+                data = json.loads(text)
+            except Exception:
+                data = self._extract_json_from_text(text)
+            state["refined_queries"] = data.get("queries", [])
+        except Exception as e:
+            logger.warning(f"工具使用类型查询词失败,降级为原始问题: {e}")
             state["refined_queries"] = [question]
         return state
     
@@ -276,7 +315,6 @@ class QueryGenerationAgent:
                     if state.get("task_id", 0) > 0:
                         self.task_dao.mark_task_failed(state["task_id"], error_msg)
                     state["result_queries"] = []
-                    state["initial_queries"] = []
                     state["refined_queries"] = []
                     return state
                 
@@ -302,7 +340,6 @@ class QueryGenerationAgent:
                 if state.get("task_id", 0) > 0:
                     self.task_dao.mark_task_failed(state["task_id"], error_msg)
                 state["result_queries"] = []
-                state["initial_queries"] = []
                 state["refined_queries"] = []
                 return state
             
@@ -314,12 +351,10 @@ class QueryGenerationAgent:
                     seen.add(q)
                     deduped.append(q)
             
-            state["initial_queries"] = deduped
             state["refined_queries"] = deduped
             
         except Exception as e:
             logger.warning(f"查询扩展失败,降级为原始问题: {e}")
-            state["initial_queries"] = [question]
             state["refined_queries"] = [question]
         
         return state
@@ -484,10 +519,9 @@ class QueryGenerationAgent:
             "question": question,
             "task_id": task_id,
             "need_store": need_store,
-            "initial_queries": [],
             "refined_queries": [],
             "result_queries": [],
-            "knowledgeType": "",
+            "knowledgeType": knowledge_type,
             "query_type": ""
         }
         

+ 1 - 1
src/api/main.py

@@ -144,7 +144,7 @@ async def generate_queries(request: QuestionRequest):
         task_id = int(time.time() * 1000)
         
         # 创建任务记录,状态设置为0(待执行)
-        task_dao.create_task(task_id, request.question, knowledge_type="工具知识", need_store=request.need_store)
+        task_dao.create_task(task_id, request.question, knowledge_type=request.knowledge_type or "", need_store=request.need_store)
         logger.info(f"创建任务: {task_id},状态: 待执行")
         
         # 立即返回待执行状态

+ 1 - 1
src/database/models.py

@@ -102,7 +102,7 @@ class QueryTaskDAO:
                 err_msg = NULL,
                 need_store = VALUES(need_store)
                 """
-                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING, knowledge_type or "内容知识", None, need_store))
+                cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING, knowledge_type, None, need_store))
                 return True
         except Exception as e:
             logger.error(f"创建任务失败: {e}")

+ 1 - 0
src/models/schemas.py

@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field
 class QuestionRequest(BaseModel):
     """问题请求模型"""
     question: str = Field(..., description="用户提出的问题", min_length=1, max_length=1000)
+    knowledge_type: str = Field(default="", description="知识类型")
     need_store: int = Field(default=1, description="是否存储查询词")
 
 

+ 56 - 0
src/tools/prompts.py

@@ -42,6 +42,62 @@ how_to_use_queries: 操作/解决问题类关键词列表。
 请严格只输出JSON数组,不要包含任何额外文本或解释。
 """
 
+TOOL_USAGE_PROMPT = """
+## 核心职责
+你的核心职责是:根据用户提供的「工具名称」与「功能需求」,**自动生成一系列多样化、覆盖全网的精准搜索查询词(Query)**,用于爬虫组件搜集该工具在此功能下的使用知识、案例与Prompt示例。
+---
+## 输入格式
+输入形式为自然语言,「工具名称」与「功能需求」
+例如:
+> 用 nano banana 实现切割出图片的主体人物 
+> 用 美图秀秀 9宫格拼图排版 
+---
+## 处理逻辑与生成规则
+### ① 关键词拆解与联想
+基于工具名与功能需求,生成核心关键词集合。 
+**注意:不要将功能词拆得过细,以免失去功能本意。**
+输出字段:
+- **tool_names**:工具名称及常见变体 
+- **function_core_zh**:功能核心词(中文,保持原意,如“抠图”“提高清晰度”“生成真人图像”) 
+- **search_descriptors**:描述性关键词,如 
+ `["教程", "使用方法", "步骤", "案例", "参数", "技巧", "Prompt""输入输出要求"]`
+注意:最终的输出字段格式为**tool_names**+**function_core_zh**+**search_descriptors**
+其中**search_descriptors**:需要根据具体工具有灵活性的改变,不局限于我给的示例
+例如在遇到需要精准参数精准Prompt影响工具使用效果时,如"midjourney",你的工具核心query词可能为"midjourney生成真人图像参数",或者
+"midjourney生成真人图像Prompt"。当遇到非参数控制的工具,比如:"新红数据",你的工具核心query词可能为"新红榜单查询教程""新红宠物话题排行榜使用方法"等。
+---
+### ② Query生成策略
+Query应更口语化、场景化。Query可以涉及详细教程,侧重于教学与实操。 
+生成的Query应指向能提取以下信息:
+- 具体使用步骤与方法 
+- 实际用户案例(Before & After) 
+- Prompt示例与参数设置 
+- 输入输出格式要求 
+**Query生成规则:**
+- 组合结构为: 
+ `tool_name + function_core_zh + search_descriptors`
+- 示例:
+ - “nano banana 抠图教程”
+ - “nano banana 图像主体提取方法”
+ - “nano banana 人物切割实操步骤”
+ - “nano banana 抠图Prompt分享”
+生成3-5条最能精准表达工具+功能查询的Query即可。
+---
+## 输出格式
+以JSON格式输出:
+```json
+{
+ "tool_name": "string",
+ "function_demand": "string",
+ "generated_keywords": {
+    "tool_names": ["string", ...],
+    "function_core_zh": ["string", ...],
+    "search_descriptors": ["string", ...]
+ },
+ "queries": ["string", ...],
+}
+"""
+
 CLASSIFICATION_PROMPT = """
 # 任务说明
 你是一个“内容创作拆解的 Query 分类器”。  

+ 1 - 1
src/tools/scheduler.py

@@ -81,7 +81,7 @@ class TaskScheduler:
             
             try:
                 # 使用Agent生成查询词(Agent内部会在内容类型时直接失败并返回空)
-                queries, knowledge_type, query_type = await self.agent.generate_queries(task.question, task.need_store, task.task_id)
+                queries, knowledge_type, query_type = await self.agent.generate_queries(task.question, task.need_store, task.task_id, task.knowledgeType)
                 
                 # 若为空,视为不支持内容类型,标记失败
                 if not queries: