|
@@ -17,9 +17,9 @@ class AgentState(TypedDict):
|
|
|
task_id: int
|
|
|
initial_queries: List[str]
|
|
|
refined_queries: List[str]
|
|
|
+ result_queries: List[Dict[str, str]]
|
|
|
context: str
|
|
|
iteration_count: int
|
|
|
- knowledgeType: str
|
|
|
|
|
|
|
|
|
class QueryGenerationAgent:
|
|
@@ -54,6 +54,7 @@ class QueryGenerationAgent:
|
|
|
workflow.add_node("generate_initial_queries", self._generate_initial_queries)
|
|
|
workflow.add_node("refine_queries", self._refine_queries)
|
|
|
workflow.add_node("validate_queries", self._validate_queries)
|
|
|
+ workflow.add_node("classify_queries", self._classify_queries)
|
|
|
workflow.add_node("save_queries", self._save_queries)
|
|
|
|
|
|
# 设置入口点
|
|
@@ -63,7 +64,8 @@ class QueryGenerationAgent:
|
|
|
workflow.add_edge("analyze_question", "generate_initial_queries")
|
|
|
workflow.add_edge("generate_initial_queries", "refine_queries")
|
|
|
workflow.add_edge("refine_queries", "validate_queries")
|
|
|
- workflow.add_edge("validate_queries", "save_queries")
|
|
|
+ workflow.add_edge("validate_queries", "classify_queries")
|
|
|
+ workflow.add_edge("classify_queries", "save_queries")
|
|
|
workflow.add_edge("save_queries", END)
|
|
|
|
|
|
return workflow.compile()
|
|
@@ -168,6 +170,14 @@ class QueryGenerationAgent:
|
|
|
logger.info(f"查询词验证结果: {validated_queries}")
|
|
|
state["refined_queries"] = validated_queries
|
|
|
return state
|
|
|
+
|
|
|
def _classify_queries(self, state: AgentState) -> AgentState:
    """Infer the knowledge type of each refined query and store the
    labeled results in state["result_queries"]."""
    queries = state.get("refined_queries", [])
    # Delegate the actual labeling to the LLM-backed classifier helper.
    labeled: List[Dict[str, str]] = self._classify_with_llm(queries)
    state["result_queries"] = labeled
    return state
|
|
|
|
|
|
def _save_queries(self, state: AgentState) -> AgentState:
|
|
|
"""保存查询词到外部接口节点"""
|
|
@@ -178,26 +188,25 @@ class QueryGenerationAgent:
|
|
|
logger.warning("没有查询词需要保存")
|
|
|
return state
|
|
|
|
|
|
- # 调用外部接口保存查询词
|
|
|
+ # 调用外部接口保存查询词(按类型分组)
|
|
|
try:
|
|
|
url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
|
|
|
headers = {"Content-Type": "application/json"}
|
|
|
-
|
|
|
- # 根据问题内容判断知识类型,这里可以根据实际需求调整逻辑
|
|
|
- knowledge_type = state["knowledgeType"] # 默认类型,可以根据问题内容动态判断
|
|
|
-
|
|
|
- data = {
|
|
|
- "knowledgeType": knowledge_type,
|
|
|
- "queryWords": refined_queries
|
|
|
- }
|
|
|
-
|
|
|
- # 使用httpx发送请求
|
|
|
- with httpx.Client() as client:
|
|
|
- response = client.post(url, headers=headers, json=data, timeout=30)
|
|
|
- response.raise_for_status()
|
|
|
-
|
|
|
- logger.info(f"查询词保存成功: {refined_queries}")
|
|
|
-
|
|
|
+
|
|
|
+ # 仅使用前一步的分类结果,不做即时分类
|
|
|
+ result_items: List[Dict[str, str]] = state.get("result_queries", [])
|
|
|
+ if not result_items:
|
|
|
+ logger.warning("缺少分类结果result_queries,跳过外部提交")
|
|
|
+ return state
|
|
|
+
|
|
|
+ if result_items:
|
|
|
+ with httpx.Client() as client:
|
|
|
+ data_content = {"queryWords": result_items}
|
|
|
+ resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
|
|
|
+ resp1.raise_for_status()
|
|
|
+
|
|
|
+ logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
|
|
|
+
|
|
|
except httpx.HTTPError as e:
|
|
|
logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
|
|
|
except Exception as e:
|
|
@@ -205,14 +214,100 @@ class QueryGenerationAgent:
|
|
|
|
|
|
return state
|
|
|
|
|
|
- async def generate_queries(self, question: str, task_id: int = 0, knowledgeType: str = "") -> List[str]:
|
|
|
+ def _infer_knowledge_type(self, query: str) -> str:
|
|
|
+ """根据查询词简单推断知识类型(内容知识/工具知识)"""
|
|
|
+ tool_keywords = [
|
|
|
+ "安装", "配置", "使用", "教程", "API", "SDK", "命令", "指令", "版本",
|
|
|
+ "错误", "异常", "调试", "部署", "集成", "调用", "参数", "示例", "代码",
|
|
|
+ "CLI", "tool", "library", "framework"
|
|
|
+ ]
|
|
|
+ lower_q = query.lower()
|
|
|
+ for kw in tool_keywords:
|
|
|
+ if kw.lower() in lower_q:
|
|
|
+ return "工具知识"
|
|
|
+ return "内容知识"
|
|
|
+
|
|
|
def _classify_with_llm(self, queries: List[str]) -> List[Dict[str, str]]:
    """Ask the LLM to label each query as 内容知识 (content knowledge)
    or 工具知识 (tool knowledge).

    Returns a list shaped like
    [{"query": q, "knowledgeType": "内容知识"|"工具知识"}, ...].
    If the LLM call or parsing fails, degrades by labeling every query
    as 内容知识 (no keyword heuristic is applied).
    """
    if not queries:
        return []

    instruction = (
        "你是一名分类助手。请将下面的查询词逐一分类为‘内容知识’或‘工具知识’。\n"
        "请只返回严格的JSON数组,每个元素为对象:{\"query\": 原始查询词, \"knowledgeType\": \"内容知识\" 或 \"工具知识\"}。\n"
        "不要输出任何解释或多余文本。"
    )
    payload = "\n".join(queries)

    prompt = ChatPromptTemplate.from_messages([
        SystemMessage(content=instruction),
        HumanMessage(content=f"查询词列表(每行一个):\n{payload}")
    ])

    try:
        response = self.llm.invoke(prompt.format_messages())
        text = (response.content or "").strip()
        logger.info(f"LLM分类结果: {text}")
        # Parse the reply as a JSON array; if that fails, try to extract
        # one from a code fence or surrounding prose.
        try:
            data = json.loads(text)
        except Exception:
            data = self._extract_json_array_from_text(text)
        result: List[Dict[str, str]] = []
        # Keep only well-formed items whose label is one of the two
        # accepted values; everything else is dropped here.
        for item in data:
            q = str(item.get("query", "")).strip()
            kt = str(item.get("knowledgeType", "")).strip()
            if q and kt in ("内容知识", "工具知识"):
                result.append({"query": q, "knowledgeType": kt})
        # Guarantee the output covers every input query, in input order.
        if len(result) != len(queries):
            # Re-align against the input; unlabeled queries default to 内容知识.
            mapped = {it["query"]: it["knowledgeType"] for it in result}
            aligned: List[Dict[str, str]] = []
            for q in queries:
                kt = mapped.get(q, "内容知识")
                aligned.append({"query": q, "knowledgeType": kt})
            return aligned
        return result
    except Exception as e:
        # Degraded path: label everything 内容知识 (deliberately no keyword matching).
        logger.warning(f"LLM分类失败,使用降级策略: {e}")
        return [{"query": q, "knowledgeType": "内容知识"} for q in queries]
|
|
|
+
|
|
|
+ def _extract_json_array_from_text(self, text: str) -> List[Dict[str, Any]]:
|
|
|
+ """尽力从模型输出(可能包含```json代码块或多余文本)中提取JSON数组。"""
|
|
|
+ s = (text or "").strip()
|
|
|
+ # 去除三引号包裹的代码块
|
|
|
+ if s.startswith("```"):
|
|
|
+ # 去掉第一行的 ``` 或 ```json
|
|
|
+ first_newline = s.find('\n')
|
|
|
+ if first_newline != -1:
|
|
|
+ s = s[first_newline + 1:]
|
|
|
+ if s.endswith("```"):
|
|
|
+ s = s[:-3]
|
|
|
+ s = s.strip()
|
|
|
+ # 在文本中查找首个JSON数组
|
|
|
+ import re
|
|
|
+ match = re.search(r"\[[\s\S]*\]", s)
|
|
|
+ if not match:
|
|
|
+ raise ValueError("未找到JSON数组片段")
|
|
|
+ json_str = match.group(0)
|
|
|
+ data = json.loads(json_str)
|
|
|
+ if not isinstance(data, list):
|
|
|
+ raise ValueError("提取内容不是JSON数组")
|
|
|
+ return data
|
|
|
+
|
|
|
+ async def generate_queries(self, question: str, task_id: int = 0) -> List[str]:
|
|
|
"""
|
|
|
生成查询词的主入口
|
|
|
|
|
|
Args:
|
|
|
question: 用户问题
|
|
|
task_id: 任务ID
|
|
|
- knowledgeType: 知识类型
|
|
|
Returns:
|
|
|
生成的查询词列表
|
|
|
"""
|
|
@@ -221,14 +316,14 @@ class QueryGenerationAgent:
|
|
|
"task_id": task_id,
|
|
|
"initial_queries": [],
|
|
|
"refined_queries": [],
|
|
|
+ "result_queries": [],
|
|
|
"context": "",
|
|
|
- "iteration_count": 0,
|
|
|
- "knowledgeType": knowledgeType
|
|
|
+ "iteration_count": 0
|
|
|
}
|
|
|
|
|
|
try:
|
|
|
result = await self.graph.ainvoke(initial_state)
|
|
|
- return result["refined_queries"]
|
|
|
+ return result["result_queries"]
|
|
|
except Exception as e:
|
|
|
# 更新任务状态为失败
|
|
|
if task_id > 0:
|