|
@@ -9,7 +9,9 @@ import json
|
|
|
from ..tools.prompts import (
|
|
from ..tools.prompts import (
|
|
|
STRUCTURED_TOOL_DEMAND_PROMPT,
|
|
STRUCTURED_TOOL_DEMAND_PROMPT,
|
|
|
CLASSIFICATION_PROMPT,
|
|
CLASSIFICATION_PROMPT,
|
|
|
- QUERY_CLASSIFICATION_PROMPT
|
|
|
|
|
|
|
+ QUERY_CLASSIFICATION_PROMPT,
|
|
|
|
|
+ WHAT_CLASSIFICATION_PROMPT,
|
|
|
|
|
+ PATTERN_CLASSIFICATION_PROMPT
|
|
|
)
|
|
)
|
|
|
from ..database.models import QueryTaskDAO, QueryTaskStatus, logger
|
|
from ..database.models import QueryTaskDAO, QueryTaskStatus, logger
|
|
|
|
|
|
|
@@ -23,8 +25,7 @@ class AgentState(TypedDict):
|
|
|
refined_queries: List[str]
|
|
refined_queries: List[str]
|
|
|
result_queries: List[Dict[str, str]]
|
|
result_queries: List[Dict[str, str]]
|
|
|
knowledgeType: str
|
|
knowledgeType: str
|
|
|
- content_dimension: str # 内容类型的维度: How / What / Pattern
|
|
|
|
|
- is_query_type: bool # 是否为可处理的查询类型
|
|
|
|
|
|
|
+ query_type: str # 问题类型: How / What / Pattern
|
|
|
|
|
|
|
|
|
|
|
|
|
class QueryGenerationAgent:
|
|
class QueryGenerationAgent:
|
|
@@ -49,6 +50,18 @@ class QueryGenerationAgent:
|
|
|
# 创建状态图
|
|
# 创建状态图
|
|
|
self.graph = self._create_graph()
|
|
self.graph = self._create_graph()
|
|
|
|
|
|
|
|
|
|
+ def _normalize_query_type(self, query_type: str) -> str:
|
|
|
|
|
+ """统一规范化query_type为首字母大写格式(How/What/Pattern)"""
|
|
|
|
|
+ query_type_lower = query_type.strip().lower()
|
|
|
|
|
+ if query_type_lower == "how":
|
|
|
|
|
+ return "How"
|
|
|
|
|
+ elif query_type_lower == "what":
|
|
|
|
|
+ return "What"
|
|
|
|
|
+ elif query_type_lower == "pattern":
|
|
|
|
|
+ return "Pattern"
|
|
|
|
|
+ else:
|
|
|
|
|
+ return query_type # 返回原值
|
|
|
|
|
+
|
|
|
def _create_graph(self) -> StateGraph:
|
|
def _create_graph(self) -> StateGraph:
|
|
|
"""创建LangGraph状态图"""
|
|
"""创建LangGraph状态图"""
|
|
|
workflow = StateGraph(AgentState)
|
|
workflow = StateGraph(AgentState)
|
|
@@ -188,21 +201,13 @@ class QueryGenerationAgent:
|
|
|
data = self._extract_json_from_text(text)
|
|
data = self._extract_json_from_text(text)
|
|
|
|
|
|
|
|
dimension = data.get("所属维度", "").strip()
|
|
dimension = data.get("所属维度", "").strip()
|
|
|
- state["content_dimension"] = dimension
|
|
|
|
|
-
|
|
|
|
|
- # 判断是否为可处理的查询类型(目前仅支持How类型)
|
|
|
|
|
- state["is_query_type"] = dimension == "How"
|
|
|
|
|
|
|
+ # 统一为首字母大写格式(How/What/Pattern)
|
|
|
|
|
+ dimension = self._normalize_query_type(dimension)
|
|
|
|
|
+ state["query_type"] = dimension
|
|
|
|
|
+ logger.info(f"问题类型设置为: {dimension}")
|
|
|
|
|
|
|
|
- if not state["is_query_type"]:
|
|
|
|
|
- # 不支持的类型,标记任务失败
|
|
|
|
|
- error_msg = f"暂不支持{dimension}类型的内容问题,当前仅支持How类型"
|
|
|
|
|
- logger.info(error_msg)
|
|
|
|
|
- if state.get("task_id", 0) > 0:
|
|
|
|
|
- self.task_dao.mark_task_failed(state["task_id"], error_msg)
|
|
|
|
|
- state["result_queries"] = []
|
|
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"内容维度分类失败: {e}")
|
|
logger.error(f"内容维度分类失败: {e}")
|
|
|
- state["is_query_type"] = False
|
|
|
|
|
if state.get("task_id", 0) > 0:
|
|
if state.get("task_id", 0) > 0:
|
|
|
self.task_dao.mark_task_failed(state["task_id"], f"分类失败: {str(e)}")
|
|
self.task_dao.mark_task_failed(state["task_id"], f"分类失败: {str(e)}")
|
|
|
state["result_queries"] = []
|
|
state["result_queries"] = []
|
|
@@ -210,14 +215,37 @@ class QueryGenerationAgent:
|
|
|
return state
|
|
return state
|
|
|
|
|
|
|
|
def _route_after_content_classify(self, state: AgentState) -> str:
|
|
def _route_after_content_classify(self, state: AgentState) -> str:
|
|
|
- """根据内容分类结果路由:支持的类型 -> EXPAND;不支持 -> UNSUPPORTED"""
|
|
|
|
|
- return "EXPAND" if state.get("is_query_type", False) else "UNSUPPORTED"
|
|
|
|
|
|
|
+ """根据内容分类结果路由:所有类型都支持扩展"""
|
|
|
|
|
+ query_type = state.get("query_type", "")
|
|
|
|
|
+ # 支持 How / What / Pattern 三种类型
|
|
|
|
|
+ if query_type in ["How", "What", "Pattern"]:
|
|
|
|
|
+ return "EXPAND"
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 未识别的类型,不支持
|
|
|
|
|
+ logger.warning(f"未识别的问题类型: {query_type}")
|
|
|
|
|
+ return "UNSUPPORTED"
|
|
|
|
|
|
|
|
def _expand_content_queries(self, state: AgentState) -> AgentState:
|
|
def _expand_content_queries(self, state: AgentState) -> AgentState:
|
|
|
- """使用QUERY_CLASSIFICATION_PROMPT扩展内容类型的查询词"""
|
|
|
|
|
|
|
+ """根据问题类型选择相应的PROMPT扩展内容查询词"""
|
|
|
question = state["question"]
|
|
question = state["question"]
|
|
|
|
|
+ query_type = state.get("query_type", "How")
|
|
|
|
|
+
|
|
|
|
|
+ # 根据query_type选择对应的PROMPT(值已在分类阶段规范化为How/What/Pattern)
|
|
|
|
|
+ if query_type == "How":
|
|
|
|
|
+ classification_prompt = QUERY_CLASSIFICATION_PROMPT
|
|
|
|
|
+ elif query_type == "What":
|
|
|
|
|
+ classification_prompt = WHAT_CLASSIFICATION_PROMPT
|
|
|
|
|
+ elif query_type == "Pattern":
|
|
|
|
|
+ classification_prompt = PATTERN_CLASSIFICATION_PROMPT
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 默认使用How类型的PROMPT
|
|
|
|
|
+ classification_prompt = QUERY_CLASSIFICATION_PROMPT
|
|
|
|
|
+ logger.warning(f"未识别的问题类型 {query_type},使用默认How类型PROMPT")
|
|
|
|
|
+
|
|
|
|
|
+ logger.info(f"使用{query_type}类型的PROMPT进行查询扩展")
|
|
|
|
|
+
|
|
|
prompt = ChatPromptTemplate.from_messages([
|
|
prompt = ChatPromptTemplate.from_messages([
|
|
|
- SystemMessage(content=QUERY_CLASSIFICATION_PROMPT),
|
|
|
|
|
|
|
+ SystemMessage(content=classification_prompt),
|
|
|
HumanMessage(content=question)
|
|
HumanMessage(content=question)
|
|
|
])
|
|
])
|
|
|
try:
|
|
try:
|
|
@@ -460,8 +488,7 @@ class QueryGenerationAgent:
|
|
|
"refined_queries": [],
|
|
"refined_queries": [],
|
|
|
"result_queries": [],
|
|
"result_queries": [],
|
|
|
"knowledgeType": "",
|
|
"knowledgeType": "",
|
|
|
- "content_dimension": "",
|
|
|
|
|
- "is_query_type": False
|
|
|
|
|
|
|
+ "query_type": ""
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
try:
|
|
try:
|