|
|
@@ -6,7 +6,11 @@ from langchain.schema import HumanMessage, SystemMessage
|
|
|
import httpx
|
|
|
import json
|
|
|
|
|
|
-from ..tools.prompts import STRUCTURED_TOOL_DEMAND_PROMPT
|
|
|
+from ..tools.prompts import (
|
|
|
+ STRUCTURED_TOOL_DEMAND_PROMPT,
|
|
|
+ CLASSIFICATION_PROMPT,
|
|
|
+ QUERY_CLASSIFICATION_PROMPT
|
|
|
+)
|
|
|
from ..database.models import QueryTaskDAO, QueryTaskStatus, logger
|
|
|
|
|
|
|
|
|
@@ -18,6 +22,8 @@ class AgentState(TypedDict):
|
|
|
refined_queries: List[str]
|
|
|
result_queries: List[Dict[str, str]]
|
|
|
knowledgeType: str
|
|
|
+ content_dimension: str # 内容类型的维度: How / What / Pattern
|
|
|
+ is_query_type: bool # 是否为可处理的查询类型
|
|
|
|
|
|
|
|
|
class QueryGenerationAgent:
|
|
|
@@ -46,29 +52,47 @@ class QueryGenerationAgent:
|
|
|
"""创建LangGraph状态图"""
|
|
|
workflow = StateGraph(AgentState)
|
|
|
|
|
|
- # 添加节点:分类 -> 生成 -> 保存
|
|
|
+ # 添加节点
|
|
|
workflow.add_node("classify_question", self._classify_question)
|
|
|
- workflow.add_node("generate_initial_queries", self._generate_initial_queries)
|
|
|
+ workflow.add_node("generate_tool_queries", self._generate_tool_queries) # 工具类型查询生成
|
|
|
+ workflow.add_node("classify_content_dimension", self._classify_content_dimension) # 内容维度分类
|
|
|
+ workflow.add_node("expand_content_queries", self._expand_content_queries) # 内容查询扩展
|
|
|
workflow.add_node("save_queries", self._save_queries)
|
|
|
|
|
|
# 设置入口点
|
|
|
workflow.set_entry_point("classify_question")
|
|
|
|
|
|
- # 添加条件边:内容直接结束,工具进入生成
|
|
|
+ # 条件路由:工具知识 vs 内容知识
|
|
|
try:
|
|
|
- # 优先使用条件路由(若LangGraph版本支持)
|
|
|
workflow.add_conditional_edges(
|
|
|
"classify_question",
|
|
|
self._route_after_classify,
|
|
|
{
|
|
|
- "TOOL": "generate_initial_queries",
|
|
|
- "CONTENT": END
|
|
|
+ "TOOL": "generate_tool_queries",
|
|
|
+ "CONTENT": "classify_content_dimension"
|
|
|
}
|
|
|
)
|
|
|
except Exception:
|
|
|
- # 兼容:不支持条件边时,继续到生成节点,由生成节点自行拦截
|
|
|
- workflow.add_edge("classify_question", "generate_initial_queries")
|
|
|
- workflow.add_edge("generate_initial_queries", "save_queries")
|
|
|
+ workflow.add_edge("classify_question", "generate_tool_queries")
|
|
|
+
|
|
|
+ # 工具类型:生成 -> 保存 -> 结束
|
|
|
+ workflow.add_edge("generate_tool_queries", "save_queries")
|
|
|
+
|
|
|
+ # 内容类型:分类维度 -> 条件路由
|
|
|
+ try:
|
|
|
+ workflow.add_conditional_edges(
|
|
|
+ "classify_content_dimension",
|
|
|
+ self._route_after_content_classify,
|
|
|
+ {
|
|
|
+ "EXPAND": "expand_content_queries",
|
|
|
+ "UNSUPPORTED": END
|
|
|
+ }
|
|
|
+ )
|
|
|
+ except Exception:
|
|
|
+ workflow.add_edge("classify_content_dimension", "expand_content_queries")
|
|
|
+
|
|
|
+ # 内容扩展:扩展 -> 保存 -> 结束
|
|
|
+ workflow.add_edge("expand_content_queries", "save_queries")
|
|
|
workflow.add_edge("save_queries", END)
|
|
|
|
|
|
return workflow.compile()
|
|
|
@@ -92,34 +116,18 @@ class QueryGenerationAgent:
|
|
|
logger.info(f"问题类型判断结果: {text}")
|
|
|
kt = "工具知识" if "工具" in text else "内容知识"
|
|
|
state["knowledgeType"] = kt
|
|
|
- # 若为内容,直接将任务标记为失败并准备结束
|
|
|
- if kt != "工具知识":
|
|
|
- try:
|
|
|
- if state.get("task_id", 0) > 0:
|
|
|
- self.task_dao.mark_task_failed(state["task_id"], "暂不支持非工具内容")
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
- state["result_queries"] = []
|
|
|
- except Exception:
|
|
|
+ except Exception as e:
|
|
|
# 失败默认判为内容知识以避免误触发
|
|
|
+ logger.warning(f"问题类型判断失败: {e}")
|
|
|
state["knowledgeType"] = "内容知识"
|
|
|
- try:
|
|
|
- if state.get("task_id", 0) > 0:
|
|
|
- self.task_dao.mark_task_failed(state["task_id"], "暂不支持非工具内容")
|
|
|
- except Exception:
|
|
|
- pass
|
|
|
- state["result_queries"] = []
|
|
|
return state
|
|
|
|
|
|
def _route_after_classify(self, state: AgentState) -> str:
|
|
|
"""根据分类结果路由:工具 -> TOOL;内容 -> CONTENT"""
|
|
|
return "TOOL" if state.get("knowledgeType") == "工具知识" else "CONTENT"
|
|
|
|
|
|
- def _generate_initial_queries(self, state: AgentState) -> AgentState:
|
|
|
- """生成 refined_queries(从结构化JSON中聚合三类关键词)"""
|
|
|
- # 若为内容类型,直接抛出以在上层处理
|
|
|
- if state.get("knowledgeType") != "工具知识":
|
|
|
- raise ValueError("不支持内容类型问题")
|
|
|
+ def _generate_tool_queries(self, state: AgentState) -> AgentState:
|
|
|
+ """生成工具类型的查询词(从结构化JSON中聚合三类关键词)"""
|
|
|
question = state["question"]
|
|
|
# 使用新的结构化系统提示
|
|
|
prompt = ChatPromptTemplate.from_messages([
|
|
|
@@ -160,7 +168,107 @@ class QueryGenerationAgent:
|
|
|
state["refined_queries"] = [question]
|
|
|
return state
|
|
|
|
|
|
- # 删除 refine/validate/classify 节点
|
|
|
+ def _classify_content_dimension(self, state: AgentState) -> AgentState:
|
|
|
+ """使用CLASSIFICATION_PROMPT对内容类型问题进行维度分类(How/What/Pattern)"""
|
|
|
+ question = state["question"]
|
|
|
+ prompt = ChatPromptTemplate.from_messages([
|
|
|
+ SystemMessage(content=CLASSIFICATION_PROMPT),
|
|
|
+ HumanMessage(content=question)
|
|
|
+ ])
|
|
|
+ try:
|
|
|
+ response = self.llm.invoke(prompt.format_messages())
|
|
|
+ text = (response.content or "").strip()
|
|
|
+ logger.info(f"内容维度分类结果: {text}")
|
|
|
+
|
|
|
+ # 解析JSON结果
|
|
|
+ try:
|
|
|
+ data = json.loads(text)
|
|
|
+ except Exception:
|
|
|
+ data = self._extract_json_from_text(text)
|
|
|
+
|
|
|
+ dimension = data.get("所属维度", "").strip()
|
|
|
+ state["content_dimension"] = dimension
|
|
|
+
|
|
|
+ # 判断是否为可处理的查询类型(目前仅支持How类型)
|
|
|
+ state["is_query_type"] = dimension == "How"
|
|
|
+
|
|
|
+ if not state["is_query_type"]:
|
|
|
+ # 不支持的类型,标记任务失败
|
|
|
+ error_msg = f"暂不支持{dimension}类型的内容问题,当前仅支持How类型"
|
|
|
+ logger.info(error_msg)
|
|
|
+ if state.get("task_id", 0) > 0:
|
|
|
+ self.task_dao.mark_task_failed(state["task_id"], error_msg)
|
|
|
+ state["result_queries"] = []
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"内容维度分类失败: {e}")
|
|
|
+ state["is_query_type"] = False
|
|
|
+ if state.get("task_id", 0) > 0:
|
|
|
+ self.task_dao.mark_task_failed(state["task_id"], f"分类失败: {str(e)}")
|
|
|
+ state["result_queries"] = []
|
|
|
+
|
|
|
+ return state
|
|
|
+
|
|
|
+ def _route_after_content_classify(self, state: AgentState) -> str:
|
|
|
+ """根据内容分类结果路由:支持的类型 -> EXPAND;不支持 -> UNSUPPORTED"""
|
|
|
+ return "EXPAND" if state.get("is_query_type", False) else "UNSUPPORTED"
|
|
|
+
|
|
|
+ def _expand_content_queries(self, state: AgentState) -> AgentState:
|
|
|
+ """使用QUERY_CLASSIFICATION_PROMPT扩展内容类型的查询词"""
|
|
|
+ question = state["question"]
|
|
|
+ prompt = ChatPromptTemplate.from_messages([
|
|
|
+ SystemMessage(content=QUERY_CLASSIFICATION_PROMPT),
|
|
|
+ HumanMessage(content=question)
|
|
|
+ ])
|
|
|
+ try:
|
|
|
+ response = self.llm.invoke(prompt.format_messages())
|
|
|
+ text = (response.content or "").strip()
|
|
|
+ logger.info(f"查询扩展结果: {text}")
|
|
|
+
|
|
|
+ # 解析JSON结果
|
|
|
+ try:
|
|
|
+ data = json.loads(text)
|
|
|
+ except Exception:
|
|
|
+ data = self._extract_json_from_text(text)
|
|
|
+
|
|
|
+ # 提取所有扩展的查询词
|
|
|
+ expanded = data.get("expanded_queries", {})
|
|
|
+ aggregated: List[str] = []
|
|
|
+
|
|
|
+ # 收集粗颗粒度查询
|
|
|
+ for item in expanded.get("coarse_grained", []) or []:
|
|
|
+ q = str(item.get("query", "")).strip()
|
|
|
+ if q:
|
|
|
+ aggregated.append(q)
|
|
|
+
|
|
|
+ # 收集细颗粒度查询
|
|
|
+ for item in expanded.get("fine_grained", []) or []:
|
|
|
+ q = str(item.get("query", "")).strip()
|
|
|
+ if q:
|
|
|
+ aggregated.append(q)
|
|
|
+
|
|
|
+ # 收集互补或差异化查询
|
|
|
+ for item in expanded.get("complementary_or_differentiated", []) or []:
|
|
|
+ q = str(item.get("query", "")).strip()
|
|
|
+ if q:
|
|
|
+ aggregated.append(q)
|
|
|
+
|
|
|
+ # 去重,保持顺序
|
|
|
+ seen = set()
|
|
|
+ deduped: List[str] = []
|
|
|
+ for q in aggregated:
|
|
|
+ if q not in seen:
|
|
|
+ seen.add(q)
|
|
|
+ deduped.append(q)
|
|
|
+
|
|
|
+ state["initial_queries"] = deduped
|
|
|
+ state["refined_queries"] = deduped
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"查询扩展失败,降级为原始问题: {e}")
|
|
|
+ state["initial_queries"] = [question]
|
|
|
+ state["refined_queries"] = [question]
|
|
|
+
|
|
|
+ return state
|
|
|
|
|
|
def _save_queries(self, state: AgentState) -> AgentState:
|
|
|
"""保存查询词到外部接口节点"""
|
|
|
@@ -259,6 +367,29 @@ class QueryGenerationAgent:
|
|
|
logger.warning(f"LLM分类失败,使用降级策略: {e}")
|
|
|
return [{"query": q, "knowledgeType": "内容知识"} for q in queries]
|
|
|
|
|
|
+ def _extract_json_from_text(self, text: str) -> Dict[str, Any]:
|
|
|
+ """从模型输出中提取JSON对象(可能包含```json代码块或多余文本)"""
|
|
|
+ s = (text or "").strip()
|
|
|
+ # 去除三引号包裹的代码块
|
|
|
+ if s.startswith("```"):
|
|
|
+ # 去掉第一行的 ``` 或 ```json
|
|
|
+ first_newline = s.find('\n')
|
|
|
+ if first_newline != -1:
|
|
|
+ s = s[first_newline + 1:]
|
|
|
+ if s.endswith("```"):
|
|
|
+ s = s[:-3]
|
|
|
+ s = s.strip()
|
|
|
+ # 在文本中查找首个JSON对象
|
|
|
+ import re
|
|
|
+ match = re.search(r"\{[\s\S]*\}", s)
|
|
|
+ if not match:
|
|
|
+ raise ValueError("未找到JSON对象片段")
|
|
|
+ json_str = match.group(0)
|
|
|
+ data = json.loads(json_str)
|
|
|
+ if not isinstance(data, dict):
|
|
|
+ raise ValueError("提取内容不是JSON对象")
|
|
|
+ return data
|
|
|
+
|
|
|
def _extract_json_array_from_text(self, text: str) -> List[Dict[str, Any]]:
|
|
|
"""尽力从模型输出(可能包含```json代码块或多余文本)中提取JSON数组。"""
|
|
|
s = (text or "").strip()
|
|
|
@@ -289,6 +420,7 @@ class QueryGenerationAgent:
|
|
|
Args:
|
|
|
question: 用户问题
|
|
|
task_id: 任务ID
|
|
|
+ knowledge_type: 知识类型(可选,用于兼容)
|
|
|
Returns:
|
|
|
生成的查询词列表
|
|
|
"""
|
|
|
@@ -298,13 +430,16 @@ class QueryGenerationAgent:
|
|
|
"initial_queries": [],
|
|
|
"refined_queries": [],
|
|
|
"result_queries": [],
|
|
|
- "knowledgeType": "工具知识"
|
|
|
+ "knowledgeType": "",
|
|
|
+ "content_dimension": "",
|
|
|
+ "is_query_type": False
|
|
|
}
|
|
|
|
|
|
try:
|
|
|
result = await self.graph.ainvoke(initial_state)
|
|
|
return result["result_queries"]
|
|
|
except Exception as e:
|
|
|
+ logger.error(f"生成查询词失败: {e}")
|
|
|
# 更新任务状态为失败
|
|
|
if task_id > 0:
|
|
|
self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
|