123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- from typing import Annotated
- from typing_extensions import TypedDict
- from langgraph.graph import StateGraph, START, END
- from langgraph.graph.message import add_messages
- import os
- from langchain.chat_models import init_chat_model
- from IPython.display import Image, display
- from tools import multiply, add, divide, human_assistance
- from langchain_tavily import TavilySearch
- from langgraph.prebuilt import ToolNode, tools_condition
- from langgraph.checkpoint.memory import InMemorySaver
# --- Module-level mutable state, populated by main() ---
graph = None            # compiled LangGraph, built in main()
llm_with_tools = None   # chat model with tools bound, built in main()

# SECURITY: never hard-code API keys in source control.  The previously
# embedded OpenAI/Tavily keys were removed — rotate them immediately; any
# secret committed to a repository must be considered compromised.  Keys
# are now expected to come from the process environment (shell export,
# .env loader, CI secret store, ...).
for _required_key in ("OPENAI_API_KEY", "TAVILY_API_KEY"):
    if not os.environ.get(_required_key):
        print(f"Warning: {_required_key} is not set; API calls will fail.")
class State(TypedDict):
    # Graph state shared by all nodes.
    # Conversation history; the add_messages reducer appends new messages
    # instead of overwriting the list.
    messages: Annotated[list, add_messages]
    # presumably filled in by the human_assistance tool flow — TODO confirm
    # against tools.py (not visible in this file).
    name: str
    birthday: str
def chatbot(state: State):
    """Invoke the tool-enabled LLM on the conversation and return its reply.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update ``{"messages": [message]}``; the
        ``add_messages`` reducer appends it to the history.

    Raises:
        RuntimeError: if the model issued more than one tool call.
    """
    message = llm_with_tools.invoke(state["messages"])
    # Because we will be interrupting during tool execution, parallel tool
    # calling must stay disabled — otherwise completed tool invocations would
    # be repeated when the graph resumes.  Use an explicit check rather than
    # `assert`, which is silently stripped under `python -O`.
    if len(message.tool_calls) > 1:
        raise RuntimeError("Expected at most one tool call per model turn.")
    return {"messages": [message]}
def stream_graph_updates(user_input: str, thread_id: str):
    """Stream one user turn through the graph, printing each new last message.

    Args:
        user_input: The user's message text.
        thread_id: Checkpointer thread id, so repeated calls with the same id
            continue the same conversation.
    """
    config = {"configurable": {"thread_id": thread_id}}
    events = graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config,
        stream_mode="values",
    )
    for event in events:
        # With stream_mode="values", each event is the full state dict.
        # BUGFIX: the original looped over event.values() and printed the
        # last message once per state key (duplicated output, loop variable
        # unused); print it exactly once per event instead.
        event["messages"][-1].pretty_print()
def main():
    """Build the tool-enabled chat graph, try to visualize it, and run one demo turn."""
    global llm_with_tools, graph

    llm = init_chat_model("openai:gpt-4.1")
    tool = TavilySearch(max_results=2)
    tools = [tool, human_assistance]
    llm_with_tools = llm.bind_tools(tools=tools)

    # Wire the graph: first argument is the unique node name, second is the
    # callable invoked whenever that node runs.
    graph_builder = StateGraph(State)
    graph_builder.add_node("chatbot", chatbot)
    tool_node = ToolNode(tools=tools)
    graph_builder.add_node("tools", tool_node)
    # Route to "tools" when the LLM requested a tool call, otherwise finish.
    graph_builder.add_conditional_edges(
        "chatbot",
        tools_condition,
    )
    # Any time a tool is called, we return to the chatbot to decide the next step
    graph_builder.add_edge("tools", "chatbot")
    graph_builder.add_edge(START, "chatbot")

    # The checkpointer gives the graph multi-turn memory keyed by thread_id.
    memory = InMemorySaver()
    graph = graph_builder.compile(checkpointer=memory)

    # Best-effort graph rendering (needs optional drawing dependencies);
    # failure here must not stop the demo.
    try:
        graph_image = graph.get_graph().draw_mermaid_png()
        with open("graph_visualization.png", "wb") as f:
            f.write(graph_image)
        print("图形已保存为 'graph_visualization.png'")
    except Exception as e:
        print(f"无法生成图形: {e}")

    thread_id = "1"
    stream_graph_updates(
        ("Can you look up when LangGraph was released? "
         "When you have the answer, use the human_assistance tool for review."),
        thread_id,
    )
if __name__ == '__main__':
    # Run the demo only when executed as a script, not on import.
    main()
|