|
@@ -469,7 +469,7 @@ class QueryGenerationAgent:
|
|
|
raise ValueError("提取内容不是JSON数组")
|
|
raise ValueError("提取内容不是JSON数组")
|
|
|
return data
|
|
return data
|
|
|
|
|
|
|
|
- async def generate_queries(self, question: str, need_store: int = 1, task_id: int = 0, knowledge_type: str = "") -> List[str]:
|
|
|
|
|
|
|
+ async def generate_queries(self, question: str, need_store: int = 1, task_id: int = 0, knowledge_type: str = "") -> tuple[List[str], str]:
|
|
|
"""
|
|
"""
|
|
|
生成查询词的主入口
|
|
生成查询词的主入口
|
|
|
|
|
|
|
@@ -478,7 +478,7 @@ class QueryGenerationAgent:
|
|
|
task_id: 任务ID
|
|
task_id: 任务ID
|
|
|
knowledge_type: 知识类型(可选,用于兼容)
|
|
knowledge_type: 知识类型(可选,用于兼容)
|
|
|
Returns:
|
|
Returns:
|
|
|
- 生成的查询词列表
|
|
|
|
|
|
|
+ 元组:(生成的查询词列表, 问题类型)
|
|
|
"""
|
|
"""
|
|
|
initial_state = {
|
|
initial_state = {
|
|
|
"question": question,
|
|
"question": question,
|
|
@@ -493,14 +493,14 @@ class QueryGenerationAgent:
|
|
|
|
|
|
|
|
try:
|
|
try:
|
|
|
result = await self.graph.ainvoke(initial_state)
|
|
result = await self.graph.ainvoke(initial_state)
|
|
|
- return result["result_queries"]
|
|
|
|
|
|
|
+ return result["result_queries"], result["query_type"]
|
|
|
except Exception as e:
|
|
except Exception as e:
|
|
|
logger.error(f"生成查询词失败: {e}")
|
|
logger.error(f"生成查询词失败: {e}")
|
|
|
# 更新任务状态为失败
|
|
# 更新任务状态为失败
|
|
|
if task_id > 0:
|
|
if task_id > 0:
|
|
|
self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
|
|
self.task_dao.update_task_status(task_id, QueryTaskStatus.FAILED)
|
|
|
# 降级处理:返回原始问题
|
|
# 降级处理:返回原始问题
|
|
|
- return [question]
|
|
|
|
|
|
|
+ return [question], "How" # 默认返回How类型
|
|
|
|
|
|
|
|
def is_tool_question(self, question: str) -> bool:
|
|
def is_tool_question(self, question: str) -> bool:
|
|
|
"""同步判断问题是否为工具知识类型。"""
|
|
"""同步判断问题是否为工具知识类型。"""
|