|
@@ -0,0 +1,99 @@
|
|
|
+from typing import Annotated
|
|
|
+from typing_extensions import TypedDict
|
|
|
+from langgraph.graph import StateGraph, START, END
|
|
|
+from langgraph.graph.message import add_messages
|
|
|
+import os
|
|
|
+from langchain.chat_models import init_chat_model
|
|
|
+from IPython.display import Image, display
|
|
|
+from tools import multiply, add, divide, human_assistance
|
|
|
+from langchain_tavily import TavilySearch
|
|
|
+
|
|
|
+from langgraph.prebuilt import ToolNode, tools_condition
|
|
|
+from langgraph.checkpoint.memory import InMemorySaver
|
|
|
+
|
|
|
+graph=None
|
|
|
+llm_with_tools=None
|
|
|
+os.environ["OPENAI_API_KEY"] = "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"
|
|
|
+os.environ["TAVILY_API_KEY"] = "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"
|
|
|
+
|
|
|
+class State(TypedDict):
|
|
|
+ messages: Annotated[list, add_messages]
|
|
|
+ name: str
|
|
|
+ birthday: str
|
|
|
+
|
|
|
+
|
|
|
+def chatbot(state: State):
|
|
|
+ message = llm_with_tools.invoke(state["messages"])
|
|
|
+ # Because we will be interrupting during tool execution,
|
|
|
+ # we disable parallel tool calling to avoid repeating any
|
|
|
+ # tool invocations when we resume.
|
|
|
+ assert len(message.tool_calls) <= 1
|
|
|
+ return {"messages": [message]}
|
|
|
+
|
|
|
+def stream_graph_updates(user_input: str, thread_id: str):
|
|
|
+ config = {"configurable": {"thread_id": thread_id}}
|
|
|
+ for event in graph.stream({"messages": [{"role": "user", "content": user_input}]},config,
|
|
|
+ stream_mode="values"):
|
|
|
+ for value in event.values():
|
|
|
+ event["messages"][-1].pretty_print()
|
|
|
+
|
|
|
+def main():
|
|
|
+
|
|
|
+ global llm_with_tools, graph
|
|
|
+
|
|
|
+
|
|
|
+ llm = init_chat_model("openai:gpt-4.1")
|
|
|
+ tool = TavilySearch(max_results=2)
|
|
|
+ tools=[tool, human_assistance]
|
|
|
+
|
|
|
+
|
|
|
+ llm_with_tools = llm.bind_tools(tools = tools)
|
|
|
+ # The first argument is the unique node name
|
|
|
+ # The second argument is the function or object that will be called whenever
|
|
|
+ # the node is used.
|
|
|
+
|
|
|
+ graph_builder = StateGraph(State)
|
|
|
+ graph_builder.add_node("chatbot", chatbot)
|
|
|
+
|
|
|
+ tool_node = ToolNode(tools=tools)
|
|
|
+ graph_builder.add_node("tools", tool_node)
|
|
|
+
|
|
|
+ graph_builder.add_conditional_edges(
|
|
|
+ "chatbot",
|
|
|
+ tools_condition,
|
|
|
+ )
|
|
|
+ # Any time a tool is called, we return to the chatbot to decide the next step
|
|
|
+ graph_builder.add_edge("tools", "chatbot")
|
|
|
+ graph_builder.add_edge(START, "chatbot")
|
|
|
+
|
|
|
+ memory = InMemorySaver()
|
|
|
+ graph = graph_builder.compile(checkpointer=memory)
|
|
|
+
|
|
|
+ # 尝试显示图形(需要额外依赖)
|
|
|
+ try:
|
|
|
+ graph_image = graph.get_graph().draw_mermaid_png()
|
|
|
+ with open("graph_visualization.png", "wb") as f:
|
|
|
+ f.write(graph_image)
|
|
|
+ print("图形已保存为 'graph_visualization.png'")
|
|
|
+ except Exception as e:
|
|
|
+ print(f"无法生成图形: {e}")
|
|
|
+ thread_id = "1"
|
|
|
+ stream_graph_updates(("Can you look up when LangGraph was released? "
|
|
|
+ "When you have the answer, use the human_assistance tool for review."), thread_id)
|
|
|
+ # while True:
|
|
|
+ # try:
|
|
|
+ # user_input = input("User: ")
|
|
|
+ # if user_input.lower() in ["quit", "exit", "q"]:
|
|
|
+ # print("Goodbye!")
|
|
|
+ # break
|
|
|
+ # stream_graph_updates(user_input)
|
|
|
+ # except:
|
|
|
+ # # fallback if input() is not available
|
|
|
+ # user_input = "What do you know about LangGraph?"
|
|
|
+ # print("User: " + user_input)
|
|
|
+ # stream_graph_updates(user_input)
|
|
|
+ # break
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ main()
|