|
@@ -46,8 +46,7 @@ prompt="""
|
|
|
|
|
|
---
|
|
|
### 请您按照以下格式提供信息:
|
|
|
-关键词:{query_word}
|
|
|
-请求ID:{request_id}
|
|
|
+{input}
|
|
|
"""
|
|
|
|
|
|
class State(TypedDict):
|
|
@@ -69,7 +68,10 @@ def main():
|
|
|
|
|
|
|
|
|
def execute_agent_with_api(user_input: str):
|
|
|
- global graph, llm_with_tools
|
|
|
+ global graph, llm_with_tools, prompt
|
|
|
+
|
|
|
+ # 替换prompt中的{input}占位符为用户输入
|
|
|
+ formatted_prompt = prompt.replace("{input}", user_input)
|
|
|
|
|
|
# 如果graph或llm_with_tools未初始化,先初始化
|
|
|
if graph is None or llm_with_tools is None:
|
|
@@ -102,9 +104,8 @@ def execute_agent_with_api(user_input: str):
|
|
|
results = []
|
|
|
config = {"configurable": {"thread_id": thread_id}}
|
|
|
|
|
|
-
|
|
|
-
|
|
|
- for event in graph.stream({"messages": [{"role": "user", "content": user_input}]}, config, stream_mode="values"):
|
|
|
+ # 使用格式化后的prompt作为用户输入
|
|
|
+ for event in graph.stream({"messages": [{"role": "user", "content": formatted_prompt}]}, config, stream_mode="values"):
|
|
|
for value in event.values():
|
|
|
# 保存消息内容
|
|
|
if "messages" in event and len(event["messages"]) > 0:
|