|
|
@@ -469,7 +469,7 @@ class QueryGenerationAgent:
|
|
|
raise ValueError("提取内容不是JSON数组")
|
|
|
return data
|
|
|
|
|
|
- async def generate_queries(self, question: str, need_store: int = 1, task_id: int = 0, knowledge_type: str = "") -> tuple[List[str], str]:
|
|
|
+ async def generate_queries(self, question: str, need_store: int = 1, task_id: int = 0, knowledge_type: str = "") -> tuple[List[str], str, str]:
|
|
|
"""
|
|
|
生成查询词的主入口
|
|
|
|
|
|
@@ -493,7 +493,7 @@ class QueryGenerationAgent:
|
|
|
|
|
|
try:
|
|
|
result = await self.graph.ainvoke(initial_state)
|
|
|
- return result["result_queries"], result["query_type"]
|
|
|
+ return result["result_queries"], result["knowledgeType"], result["query_type"]
|
|
|
except Exception as e:
|
|
|
logger.error(f"生成查询词失败: {e}")
|
|
|
# 更新任务状态为失败
|