Browse Source

Merge branch 'main' of https://git.yishihui.com/ai/knowledge-agent

jihuaqiang 1 week ago
parent
commit
19170c14dd
4 changed files with 170 additions and 2 deletions
  1. 1 0
      .gitignore
  2. 99 0
      agents/clean_agent/agent.py
  3. 67 0
      agents/clean_agent/tools.py
  4. 3 2
      requirements.txt

+ 1 - 0
.gitignore

@@ -289,6 +289,7 @@ pyrightconfig.json
 # custom
 myenv/
 *.pid
+*.png
 
 
 # End of https://www.toptal.com/developers/gitignore/api/python,pycharm

+ 99 - 0
agents/clean_agent/agent.py

@@ -0,0 +1,99 @@
+from typing import Annotated
+from typing_extensions import TypedDict
+from langgraph.graph import StateGraph, START, END
+from langgraph.graph.message import add_messages
+import os
+from langchain.chat_models import init_chat_model
+from IPython.display import Image, display
+from tools import multiply, add, divide, human_assistance
+from langchain_tavily import TavilySearch
+
+from langgraph.prebuilt import ToolNode, tools_condition
+from langgraph.checkpoint.memory import InMemorySaver
+
# Module-level state populated by main(); kept global so chatbot() and
# stream_graph_updates() can reach the compiled graph and the tool-bound
# model without re-plumbing arguments through LangGraph nodes.
graph = None
llm_with_tools = None

# SECURITY FIX: the previous revision hardcoded live OpenAI and Tavily API
# keys directly in source control. Secrets must never be committed — the
# leaked keys are compromised and must be rotated. Read credentials from the
# environment instead (export them, or load a .env file before import).
_missing_keys = [
    key for key in ("OPENAI_API_KEY", "TAVILY_API_KEY") if not os.environ.get(key)
]
if _missing_keys:
    # Warn rather than raise so the module stays importable in tooling/tests;
    # main() will fail loudly when the client libraries look the keys up.
    print(
        "WARNING: missing environment variables: "
        + ", ".join(_missing_keys)
        + " — set them before running the agent."
    )
+
class State(TypedDict):
    """Shared LangGraph state for the clean agent.

    The ``messages`` channel uses the ``add_messages`` reducer, so values
    returned by nodes are appended to the conversation history instead of
    replacing it. ``name`` and ``birthday`` are written by the
    ``human_assistance`` tool after human review.
    """

    messages: Annotated[list, add_messages]
    name: str
    birthday: str
+
+
def chatbot(state: State):
    """Run one LLM turn over the conversation and append the reply.

    Reads the accumulated ``messages`` from graph state and invokes the
    tool-bound model (module global ``llm_with_tools``, set up in main()).
    """
    response = llm_with_tools.invoke(state["messages"])
    # Tool execution may be interrupted (human_assistance calls interrupt()),
    # so parallel tool calling is kept off: with at most one pending tool
    # call, resuming cannot re-run an already-executed tool.
    assert len(response.tool_calls) <= 1
    return {"messages": [response]}
+
def stream_graph_updates(user_input: str, thread_id: str):
    """Stream the graph's handling of *user_input* and pretty-print replies.

    Args:
        user_input: the user's message, sent as a single "user" role turn.
        thread_id: checkpointer thread key; reusing an id continues that
            conversation (required for interrupt/resume to work).
    """
    config = {"configurable": {"thread_id": thread_id}}
    for event in graph.stream(
        {"messages": [{"role": "user", "content": user_input}]},
        config,
        stream_mode="values",
    ):
        # Bug fix: the original wrapped this in `for value in event.values():`
        # without using `value`, printing the same last message once per state
        # key. Each streamed event is a full state snapshot — print the newest
        # message exactly once.
        event["messages"][-1].pretty_print()
+
def main():
    """Build the tool-calling chat graph, render it, and run a demo query.

    Wires a two-node LangGraph (LLM node + tool executor) with an in-memory
    checkpointer, saves a PNG visualization on a best-effort basis, then
    streams a single hard-coded research question through the graph.
    """
    global llm_with_tools, graph

    llm = init_chat_model("openai:gpt-4.1")
    tool = TavilySearch(max_results=2)
    tools = [tool, human_assistance]

    # Bind the tool schemas to the model so it can emit tool calls.
    llm_with_tools = llm.bind_tools(tools=tools)

    graph_builder = StateGraph(State)
    # "chatbot" runs the LLM; "tools" executes whatever tool calls it emits.
    graph_builder.add_node("chatbot", chatbot)
    graph_builder.add_node("tools", ToolNode(tools=tools))

    # Route chatbot -> tools when the last message contains tool calls,
    # otherwise end the run (default mapping of tools_condition).
    graph_builder.add_conditional_edges(
        "chatbot",
        tools_condition,
    )
    # Any time a tool is called, we return to the chatbot to decide the next step
    graph_builder.add_edge("tools", "chatbot")
    graph_builder.add_edge(START, "chatbot")

    # In-memory checkpointer: required for interrupt()/resume within this
    # process; state is lost when the process exits.
    memory = InMemorySaver()
    graph = graph_builder.compile(checkpointer=memory)

    # Best-effort graph rendering (needs optional mermaid/graphviz deps).
    try:
        graph_image = graph.get_graph().draw_mermaid_png()
        with open("graph_visualization.png", "wb") as f:
            f.write(graph_image)
        print("图形已保存为 'graph_visualization.png'")
    except Exception as e:
        # Visualization is non-essential; report and continue.
        print(f"无法生成图形: {e}")

    thread_id = "1"
    stream_graph_updates(
        (
            "Can you look up when LangGraph was released? "
            "When you have the answer, use the human_assistance tool for review."
        ),
        thread_id,
    )


if __name__ == '__main__':
    main()

+ 67 - 0
agents/clean_agent/tools.py

@@ -0,0 +1,67 @@
+from langchain_core.tools import tool
+from typing import Annotated
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import InjectedToolCallId, tool
+
+from langgraph.types import Command, interrupt
+
+# Define tools
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of two integers.

    Args:
        a: the first factor
        b: the second factor
    """
    return a * b
+
+
@tool
def add(a: int, b: int) -> int:
    """Return the sum of two integers.

    Args:
        a: the first addend
        b: the second addend
    """
    return a + b
+
+
@tool
def divide(a: int, b: int) -> float:
    """Return the quotient a / b as a float.

    Args:
        a: the dividend
        b: the divisor (a ZeroDivisionError propagates if this is 0)
    """
    return a / b
+
@tool
def human_assistance(
    name: str, birthday: str, tool_call_id: Annotated[str, InjectedToolCallId]
) -> Command:
    """Request human verification of a name/birthday pair.

    Suspends graph execution via interrupt() until a human resumes the run
    with a response payload, then writes the verified values into graph state.

    Args:
        name: the value the model believes is correct.
        birthday: the value the model believes is correct.
        tool_call_id: injected by LangGraph; ties the resulting ToolMessage
            back to the originating tool call.

    Returns:
        A Command whose state update sets ``name``, ``birthday`` and appends
        a ToolMessage describing the outcome. (Bug fix: the annotation
        previously claimed ``-> str`` although a Command is returned.)
    """
    # interrupt() pauses here; the value passed on resume becomes
    # human_response. Presumably a dict — TODO confirm the resume payload
    # shape against the caller that resumes this thread.
    human_response = interrupt(
        {
            "question": "Is this correct?",
            "name": name,
            "birthday": birthday,
        },
    )
    # Any answer starting with "y"/"Y" accepts the model's values verbatim.
    if human_response.get("correct", "").lower().startswith("y"):
        verified_name = name
        verified_birthday = birthday
        response = "Correct"
    else:
        # Take corrected values from the human, falling back to the model's
        # originals for any field the human omitted.
        verified_name = human_response.get("name", name)
        verified_birthday = human_response.get("birthday", birthday)
        response = f"Made a correction: {human_response}"

    state_update = {
        "name": verified_name,
        "birthday": verified_birthday,
        "messages": [ToolMessage(response, tool_call_id=tool_call_id)],
    }
    return Command(update=state_update)

+ 3 - 2
requirements.txt

@@ -10,5 +10,6 @@ requests==2.32.4
 fastapi>=0.116.0
 uvicorn[standard]>=0.35.0
 
-# LangGraph 相关依赖(可选)
-langgraph>=0.2.0
+langgraph==0.6.6
+langsmith==0.4.16
+langchain-openai==0.3.31