|
|
@@ -18,6 +18,7 @@ class AgentState(TypedDict):
|
|
|
"""Agent状态定义"""
|
|
|
question: str
|
|
|
task_id: int
|
|
|
+ need_store: int
|
|
|
initial_queries: List[str]
|
|
|
refined_queries: List[str]
|
|
|
result_queries: List[Dict[str, str]]
|
|
|
@@ -311,20 +312,22 @@ class QueryGenerationAgent:
|
|
|
]
|
|
|
state["result_queries"] = result_items
|
|
|
|
|
|
- try:
|
|
|
- url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
|
|
|
- headers = {"Content-Type": "application/json"}
|
|
|
- with httpx.Client() as client:
|
|
|
- data_content = result_items
|
|
|
- logger.info(f"查询词保存数据: {data_content}")
|
|
|
- resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
|
|
|
- resp1.raise_for_status()
|
|
|
- logger.info(f"查询词保存结果: {resp1.text}")
|
|
|
- logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
|
|
|
- except httpx.HTTPError as e:
|
|
|
- logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"保存查询词时发生错误: {str(e)}")
|
|
|
+ # need_store=1 保存查询词
|
|
|
+ if state.get("need_store", 1) == 1:
|
|
|
+ try:
|
|
|
+ url = "http://aigc-testapi.cybertogether.net/aigc/agent/knowledgeWorkflow/addQuery"
|
|
|
+ headers = {"Content-Type": "application/json"}
|
|
|
+ with httpx.Client() as client:
|
|
|
+ data_content = result_items
|
|
|
+ logger.info(f"查询词保存数据: {data_content}")
|
|
|
+ resp1 = client.post(url, headers=headers, json=data_content, timeout=30)
|
|
|
+ resp1.raise_for_status()
|
|
|
+ logger.info(f"查询词保存结果: {resp1.text}")
|
|
|
+ logger.info(f"查询词保存成功: question={question},query数量={len(result_items)}")
|
|
|
+ except httpx.HTTPError as e:
|
|
|
+ logger.error(f"保存查询词时发生HTTP错误: {str(e)}")
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"保存查询词时发生错误: {str(e)}")
|
|
|
|
|
|
return state
|
|
|
|
|
|
@@ -438,7 +441,7 @@ class QueryGenerationAgent:
|
|
|
raise ValueError("提取内容不是JSON数组")
|
|
|
return data
|
|
|
|
|
|
- async def generate_queries(self, question: str, task_id: int = 0, knowledge_type: str = "") -> List[str]:
|
|
|
+ async def generate_queries(self, question: str, need_store: int = 1, task_id: int = 0, knowledge_type: str = "") -> List[str]:
|
|
|
"""
|
|
|
生成查询词的主入口
|
|
|
|
|
|
@@ -452,6 +455,7 @@ class QueryGenerationAgent:
|
|
|
initial_state = {
|
|
|
"question": question,
|
|
|
"task_id": task_id,
|
|
|
+ "need_store": need_store,
|
|
|
"initial_queries": [],
|
|
|
"refined_queries": [],
|
|
|
"result_queries": [],
|