# agent.py — LangGraph human-in-the-loop chatbot demo.
  1. from typing import Annotated
  2. from typing_extensions import TypedDict
  3. from langgraph.graph import StateGraph, START, END
  4. from langgraph.graph.message import add_messages
  5. import os
  6. from langchain.chat_models import init_chat_model
  7. from IPython.display import Image, display
  8. from tools import multiply, add, divide, human_assistance
  9. from langchain_tavily import TavilySearch
  10. from langgraph.prebuilt import ToolNode, tools_condition
  11. from langgraph.checkpoint.memory import InMemorySaver
  12. graph=None
  13. llm_with_tools=None
  14. os.environ["OPENAI_API_KEY"] = "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"
  15. os.environ["TAVILY_API_KEY"] = "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"
class State(TypedDict):
    """Shared graph state passed between nodes."""

    # Conversation history; the add_messages reducer appends new messages
    # instead of overwriting the list.
    messages: Annotated[list, add_messages]
    # NOTE(review): presumably filled in by the human_assistance review
    # step — confirm against tools.py.
    name: str
    birthday: str
  20. def chatbot(state: State):
  21. message = llm_with_tools.invoke(state["messages"])
  22. # Because we will be interrupting during tool execution,
  23. # we disable parallel tool calling to avoid repeating any
  24. # tool invocations when we resume.
  25. assert len(message.tool_calls) <= 1
  26. return {"messages": [message]}
  27. def stream_graph_updates(user_input: str, thread_id: str):
  28. config = {"configurable": {"thread_id": thread_id}}
  29. for event in graph.stream({"messages": [{"role": "user", "content": user_input}]},config,
  30. stream_mode="values"):
  31. for value in event.values():
  32. event["messages"][-1].pretty_print()
  33. def main():
  34. global llm_with_tools, graph
  35. llm = init_chat_model("openai:gpt-4.1")
  36. tool = TavilySearch(max_results=2)
  37. tools=[tool, human_assistance]
  38. llm_with_tools = llm.bind_tools(tools = tools)
  39. # The first argument is the unique node name
  40. # The second argument is the function or object that will be called whenever
  41. # the node is used.
  42. graph_builder = StateGraph(State)
  43. graph_builder.add_node("chatbot", chatbot)
  44. tool_node = ToolNode(tools=tools)
  45. graph_builder.add_node("tools", tool_node)
  46. graph_builder.add_conditional_edges(
  47. "chatbot",
  48. tools_condition,
  49. )
  50. # Any time a tool is called, we return to the chatbot to decide the next step
  51. graph_builder.add_edge("tools", "chatbot")
  52. graph_builder.add_edge(START, "chatbot")
  53. memory = InMemorySaver()
  54. graph = graph_builder.compile(checkpointer=memory)
  55. # 尝试显示图形(需要额外依赖)
  56. try:
  57. graph_image = graph.get_graph().draw_mermaid_png()
  58. with open("graph_visualization.png", "wb") as f:
  59. f.write(graph_image)
  60. print("图形已保存为 'graph_visualization.png'")
  61. except Exception as e:
  62. print(f"无法生成图形: {e}")
  63. thread_id = "1"
  64. stream_graph_updates(("Can you look up when LangGraph was released? "
  65. "When you have the answer, use the human_assistance tool for review."), thread_id)
  66. # while True:
  67. # try:
  68. # user_input = input("User: ")
  69. # if user_input.lower() in ["quit", "exit", "q"]:
  70. # print("Goodbye!")
  71. # break
  72. # stream_graph_updates(user_input)
  73. # except:
  74. # # fallback if input() is not available
  75. # user_input = "What do you know about LangGraph?"
  76. # print("User: " + user_input)
  77. # stream_graph_updates(user_input)
  78. # break
# Run the demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()