agent.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from typing import Annotated
  2. from typing_extensions import TypedDict
  3. from langgraph.graph import StateGraph, START, END
  4. from langgraph.graph.message import add_messages
  5. import os
  6. from langchain_openai import ChatOpenAI
  7. from tools import evaluation_extraction_tool, evaluation_extraction
  8. from langgraph.prebuilt import ToolNode, tools_condition
  9. import requests
  10. from dotenv import load_dotenv
  11. from utils.logging_config import get_logger
  12. # 配置日志
  13. logger = get_logger('CleanAgent')
  14. # 加载环境变量
  15. load_dotenv()
  16. graph=None
  17. llm_with_tools=None
  18. prompt="""
  19. ### 角色 (Role):
  20. 您是一个专业的评估报告检索助手,我的任务是根据用户的查询关键词,从评估报告中提取相关信息。
  21. ### 目标 (Goal):
  22. 1. 根据特定的主题(关键词)快速获取相关的评估报告、数据摘要或关键指标,以便您能深入了解某个方面(如产品表现、服务质量、市场反馈、项目评估等)的详细评估情况。
  23. 2. 为您的每次查询提供一个唯一的标识符,以便您能轻松追踪和管理您的请求,确保数据的可追溯性。
  24. ---
  25. ### 工作流 (Workflow):
  26. 1. 从输入信息中提取关键词(query_word)和请求ID(request_id)
  27. 2. 调用工具evaluation_extraction_tool,进行评估解析
  28. 3. 返回结果
  29. ---
  30. ### 输入信息:
  31. {input}
  32. ### 输出json格式:
  33. {
  34. "requestId":[请求ID],
  35. "status":2
  36. }
  37. """
  38. class State(TypedDict):
  39. messages: Annotated[list, add_messages]
  40. name: str
  41. birthday: str
  42. def chatbot(state: State):
  43. message = llm_with_tools.invoke(state["messages"])
  44. # Because we will be interrupting during tool execution,
  45. # we disable parallel tool calling to avoid repeating any
  46. # tool invocations when we resume.
  47. assert len(message.tool_calls) <= 1
  48. return {"messages": [message]}
  49. def execute_agent_with_api(user_input: str):
  50. # 生成唯一的线程ID
  51. import uuid
  52. thread_id = str(uuid.uuid4())
  53. logger.info(f"开始执行提取,user_input={user_input}, thread_id={thread_id}")
  54. global graph, llm_with_tools, prompt
  55. # 替换prompt中的{input}占位符为用户输入
  56. formatted_prompt = prompt.replace("{input}", user_input)
  57. try:
  58. # 如果graph或llm_with_tools未初始化,先初始化
  59. if graph is None or llm_with_tools is None:
  60. try:
  61. # 使用新版本的 ChatOpenAI
  62. llm = ChatOpenAI(model="gpt-4")
  63. tools = [evaluation_extraction_tool]
  64. llm_with_tools = llm.bind_tools(tools=tools)
  65. # 初始化图
  66. graph_builder = StateGraph(State)
  67. graph_builder.add_node("chatbot", chatbot)
  68. tool_node = ToolNode(tools=tools)
  69. graph_builder.add_node("tools", tool_node)
  70. graph_builder.add_conditional_edges(
  71. "chatbot",
  72. tools_condition,
  73. )
  74. graph_builder.add_edge("tools", "chatbot")
  75. graph_builder.add_edge(START, "chatbot")
  76. # memory = InMemorySaver()
  77. # graph = graph_builder.compile(checkpointer=memory)
  78. graph = graph_builder.compile()
  79. except Exception as e:
  80. logger.error(f"初始化Agent失败: {str(e)}")
  81. return f"初始化Agent失败: {str(e)}"
  82. # 执行Agent并收集结果
  83. results = []
  84. config = {"configurable": {"thread_id": thread_id}}
  85. # 使用格式化后的prompt作为用户输入
  86. for event in graph.stream({"messages": [{"role": "user", "content": formatted_prompt}]}, config, stream_mode="values"):
  87. for value in event.values():
  88. # 保存消息内容
  89. if "messages" in event and len(event["messages"]) > 0:
  90. message = event["messages"][-1]
  91. results.append(message.content)
  92. # 返回结果
  93. res="\n".join(results) if results else "Agent执行完成,但没有返回结果"
  94. logger.info(f"Agent执行完成,返回结果: {res}, thread_id={thread_id}")
  95. return res
  96. except requests.exceptions.ConnectionError as e:
  97. return f"OpenAI API 连接错误: {str(e)}\n请检查网络连接或代理设置。"
  98. except Exception as e:
  99. return f"执行Agent时出错: {str(e)}"
  100. def execute(query_word: str, request_id: str):
  101. logger.info(f"开始处理,request_id: {request_id}, query_word: {query_word}")
  102. result = evaluation_extraction(request_id, query_word)
  103. return result
  104. def main():
  105. print(f"开始执行Agent")
  106. # 设置代理
  107. proxy_url = os.getenv('DYNAMIC_HTTP_PROXY')
  108. if proxy_url:
  109. logger.info(f"设置代理: {proxy_url}")
  110. os.environ["OPENAI_PROXY"] = proxy_url
  111. os.environ["HTTPS_PROXY"] = proxy_url
  112. os.environ["HTTP_PROXY"] = proxy_url
  113. # 执行Agent
  114. # result = execute_agent_with_api('{"query_word":"图文策划方法","request_id":"REQUEST_001"}')
  115. result = execute("图文策划方法", "REQUEST_001")
  116. print(result)
  117. if __name__ == '__main__':
  118. main()