|
|
@@ -8,6 +8,7 @@ import json
|
|
|
|
|
|
from ..tools.prompts import (
|
|
|
STRUCTURED_TOOL_DEMAND_PROMPT,
|
|
|
+ TOOL_USAGE_PROMPT,
|
|
|
CLASSIFICATION_PROMPT,
|
|
|
QUERY_CLASSIFICATION_PROMPT,
|
|
|
WHAT_CLASSIFICATION_PROMPT,
|
|
|
@@ -21,7 +22,6 @@ class AgentState(TypedDict):
|
|
|
question: str
|
|
|
task_id: int
|
|
|
need_store: int
|
|
|
- initial_queries: List[str]
|
|
|
refined_queries: List[str]
|
|
|
result_queries: List[Dict[str, str]]
|
|
|
knowledgeType: str
|
|
|
@@ -69,6 +69,7 @@ class QueryGenerationAgent:
|
|
|
# 添加节点
|
|
|
workflow.add_node("classify_question", self._classify_question)
|
|
|
workflow.add_node("generate_tool_queries", self._generate_tool_queries) # 工具类型查询生成
|
|
|
+ workflow.add_node("generate_tool_usage_queries", self._generate_tool_usage_queries) # 工具使用类型查询生成
|
|
|
workflow.add_node("classify_content_dimension", self._classify_content_dimension) # 内容维度分类
|
|
|
workflow.add_node("expand_content_queries", self._expand_content_queries) # 内容查询扩展
|
|
|
workflow.add_node("save_queries", self._save_queries)
|
|
|
@@ -76,13 +77,14 @@ class QueryGenerationAgent:
|
|
|
# 设置入口点
|
|
|
workflow.set_entry_point("classify_question")
|
|
|
|
|
|
- # 条件路由:工具知识 vs 内容知识
|
|
|
+ # 条件路由:工具知识 vs 工具使用 vs 内容知识
|
|
|
try:
|
|
|
workflow.add_conditional_edges(
|
|
|
"classify_question",
|
|
|
self._route_after_classify,
|
|
|
{
|
|
|
"TOOL": "generate_tool_queries",
|
|
|
+ "TOOL_USAGE": "generate_tool_usage_queries",
|
|
|
"CONTENT": "classify_content_dimension"
|
|
|
}
|
|
|
)
|
|
|
@@ -92,6 +94,9 @@ class QueryGenerationAgent:
|
|
|
# 工具类型:生成 -> 保存 -> 结束
|
|
|
workflow.add_edge("generate_tool_queries", "save_queries")
|
|
|
|
|
|
+ # 工具使用类型:生成 -> 保存 -> 结束
|
|
|
+ workflow.add_edge("generate_tool_usage_queries", "save_queries")
|
|
|
+
|
|
|
# 内容类型:分类维度 -> 条件路由
|
|
|
try:
|
|
|
workflow.add_conditional_edges(
|
|
|
@@ -112,8 +117,15 @@ class QueryGenerationAgent:
|
|
|
return workflow.compile()
|
|
|
|
|
|
def _classify_question(self, state: AgentState) -> AgentState:
|
|
|
- """判断问题知识类型:工具知识 / 内容知识"""
|
|
|
+ """判断问题知识类型:工具知识 / 工具使用 / 内容知识"""
|
|
|
question = state.get("question", "")
|
|
|
+
|
|
|
+ print(f"knowledgeType: {state.get('knowledgeType')}")
|
|
|
+ # 如果已经设置了 knowledgeType 且为"工具使用",直接使用
|
|
|
+ if state.get("knowledgeType") == "工具使用":
|
|
|
+ logger.info(f"问题类型已设置为: 工具使用")
|
|
|
+ return state
|
|
|
+
|
|
|
instruction = (
|
|
|
"你是一个分类助手。请根据以下标准判断问题类型并只输出结果:\n"
|
|
|
"- 工具知识:涉及软件/工具/编程/API/SDK/命令/安装/配置/使用/部署/调试/版本/参数/代码/集成/CLI 等操作与实现。\n"
|
|
|
@@ -137,8 +149,14 @@ class QueryGenerationAgent:
|
|
|
return state
|
|
|
|
|
|
def _route_after_classify(self, state: AgentState) -> str:
|
|
|
- """根据分类结果路由:工具 -> TOOL;内容 -> CONTENT"""
|
|
|
- return "TOOL" if state.get("knowledgeType") == "工具知识" else "CONTENT"
|
|
|
+ """根据分类结果路由:工具知识 -> TOOL;工具使用 -> TOOL_USAGE;内容知识 -> CONTENT"""
|
|
|
+ knowledge_type = state.get("knowledgeType", "")
|
|
|
+ if knowledge_type == "工具使用":
|
|
|
+ return "TOOL_USAGE"
|
|
|
+ elif knowledge_type == "工具知识":
|
|
|
+ return "TOOL"
|
|
|
+ else:
|
|
|
+ return "CONTENT"
|
|
|
|
|
|
def _generate_tool_queries(self, state: AgentState) -> AgentState:
|
|
|
"""生成工具类型的查询词(从结构化JSON中聚合三类关键词)"""
|
|
|
@@ -174,11 +192,32 @@ class QueryGenerationAgent:
|
|
|
if q not in seen:
|
|
|
seen.add(q)
|
|
|
deduped.append(q)
|
|
|
- state["initial_queries"] = deduped
|
|
|
state["refined_queries"] = deduped
|
|
|
except Exception as e:
|
|
|
logger.warning(f"结构化需求解析失败,降级为原始问题: {e}")
|
|
|
- state["initial_queries"] = [question]
|
|
|
+ state["refined_queries"] = [question]
|
|
|
+ return state
|
|
|
+
|
|
|
+ def _generate_tool_usage_queries(self, state: AgentState) -> AgentState:
|
|
|
+ """生成工具使用类型的查询词"""
|
|
|
+ question = state["question"]
|
|
|
+ print(f"工具使用类型查询词: {question}")
|
|
|
+ prompt = ChatPromptTemplate.from_messages([
|
|
|
+ SystemMessage(content=TOOL_USAGE_PROMPT),
|
|
|
+ HumanMessage(content=question)
|
|
|
+ ])
|
|
|
+ try:
|
|
|
+ response = self.llm.invoke(prompt.format_messages())
|
|
|
+ text = (response.content or "").strip()
|
|
|
+ logger.info(f"工具使用类型查询词结果: {text}")
|
|
|
+ # 解析JSON结果
|
|
|
+ try:
|
|
|
+ data = json.loads(text)
|
|
|
+ except Exception:
|
|
|
+ data = self._extract_json_from_text(text)
|
|
|
+ state["refined_queries"] = data.get("queries", [])
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"工具使用类型查询词失败,降级为原始问题: {e}")
|
|
|
state["refined_queries"] = [question]
|
|
|
return state
|
|
|
|
|
|
@@ -276,7 +315,6 @@ class QueryGenerationAgent:
|
|
|
if state.get("task_id", 0) > 0:
|
|
|
self.task_dao.mark_task_failed(state["task_id"], error_msg)
|
|
|
state["result_queries"] = []
|
|
|
- state["initial_queries"] = []
|
|
|
state["refined_queries"] = []
|
|
|
return state
|
|
|
|
|
|
@@ -302,7 +340,6 @@ class QueryGenerationAgent:
|
|
|
if state.get("task_id", 0) > 0:
|
|
|
self.task_dao.mark_task_failed(state["task_id"], error_msg)
|
|
|
state["result_queries"] = []
|
|
|
- state["initial_queries"] = []
|
|
|
state["refined_queries"] = []
|
|
|
return state
|
|
|
|
|
|
@@ -314,12 +351,10 @@ class QueryGenerationAgent:
|
|
|
seen.add(q)
|
|
|
deduped.append(q)
|
|
|
|
|
|
- state["initial_queries"] = deduped
|
|
|
state["refined_queries"] = deduped
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.warning(f"查询扩展失败,降级为原始问题: {e}")
|
|
|
- state["initial_queries"] = [question]
|
|
|
state["refined_queries"] = [question]
|
|
|
|
|
|
return state
|
|
|
@@ -484,10 +519,9 @@ class QueryGenerationAgent:
|
|
|
"question": question,
|
|
|
"task_id": task_id,
|
|
|
"need_store": need_store,
|
|
|
- "initial_queries": [],
|
|
|
"refined_queries": [],
|
|
|
"result_queries": [],
|
|
|
- "knowledgeType": "",
|
|
|
+ "knowledgeType": knowledge_type,
|
|
|
"query_type": ""
|
|
|
}
|
|
|
|