simple_chat_agent.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. import json
  2. from typing import List, Optional
  3. import pqai_agent.utils
  4. from pqai_agent.agent import DEFAULT_MAX_RUN_STEPS
  5. from pqai_agent.chat_service import OpenAICompatible
  6. from pqai_agent.data_models.agent_task_detail import AgentTaskDetail
  7. from pqai_agent.logging import logger
  8. from pqai_agent.toolkit.function_tool import FunctionTool
  9. from pqai_agent_server.const.status_enum import AgentTaskDetailStatus
  10. class SimpleOpenAICompatibleChatAgent:
  11. """ 最简单的多步Agent实现 """
  12. def __init__(self, model: str, system_prompt: str, tools: Optional[List[FunctionTool]] = None,
  13. generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
  14. self.model = model
  15. self.llm_client = OpenAICompatible.create_client(model)
  16. self.system_prompt = system_prompt
  17. if tools:
  18. self.tools = [*tools]
  19. else:
  20. self.tools = []
  21. self.tool_map = {tool.name: tool for tool in self.tools}
  22. self.generate_cfg = generate_cfg or {}
  23. self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
  24. self.tool_call_records = []
  25. self.agent_task_details: list[AgentTaskDetail] = []
  26. self.total_input_tokens = 0
  27. self.total_output_tokens = 0
  28. logger.debug(self.tool_map)
  29. def add_tool(self, tool: FunctionTool):
  30. """添加一个工具到Agent中"""
  31. if tool.name in self.tool_map:
  32. logger.warning(f"Tool {tool.name} already exists, replacing it.")
  33. self.tools.append(tool)
  34. self.tool_map[tool.name] = tool
  35. def run(self, user_input: str) -> str:
  36. run_id = pqai_agent.utils.random_str()[:12]
  37. messages = [{"role": "system", "content": self.system_prompt}]
  38. tools = [tool.get_openai_tool_schema() for tool in self.tools]
  39. messages.append({"role": "user", "content": user_input})
  40. n_steps = 0
  41. logger.debug(f"run_id[{run_id}] start agent loop. messages: {messages}")
  42. while n_steps < self.max_run_step:
  43. response = self.llm_client.chat.completions.create(model=self.model, messages=messages, tools=tools, **self.generate_cfg)
  44. message = response.choices[0].message
  45. self.total_input_tokens += response.usage.prompt_tokens
  46. self.total_output_tokens += response.usage.completion_tokens
  47. messages.append(message)
  48. logger.debug(f"run_id[{run_id}] current step content: {message.content}")
  49. if message.tool_calls:
  50. for tool_call in message.tool_calls:
  51. function_name = tool_call.function.name
  52. arguments = json.loads(tool_call.function.arguments)
  53. logger.debug(f"run_id[{run_id}] call function[{function_name}], parameter: {arguments}")
  54. agent_task_detail = AgentTaskDetail()
  55. agent_task_detail.executor_type = 'tool'
  56. agent_task_detail.executor_name = function_name
  57. agent_task_detail.input_data = tool_call.function.arguments
  58. self.agent_task_details.append(agent_task_detail)
  59. if function_name in self.tool_map:
  60. result = self.tool_map[function_name](**arguments)
  61. messages.append({
  62. "role": "tool",
  63. "tool_call_id": tool_call.id,
  64. "content": json.dumps(result, ensure_ascii=False)
  65. })
  66. self.tool_call_records.append({
  67. "name": function_name,
  68. "arguments": arguments,
  69. "result": result
  70. })
  71. agent_task_detail.output_data = json.dumps(result, ensure_ascii=False)
  72. agent_task_detail.status = AgentTaskDetailStatus.SUCCESS.value
  73. else:
  74. agent_task_detail.error_message = f"Function {function_name} not found in tool map."
  75. agent_task_detail.status = AgentTaskDetailStatus.FAILED.value
  76. logger.error(f"run_id[{run_id}] Function {function_name} not found in tool map.")
  77. raise Exception(f"Function {function_name} not found in tool map.")
  78. else:
  79. agent_task_detail = AgentTaskDetail()
  80. agent_task_detail.executor_type = 'llm'
  81. agent_task_detail.executor_name = self.model
  82. agent_task_detail.output_data = message.content
  83. agent_task_detail.status = AgentTaskDetailStatus.SUCCESS.value
  84. self.agent_task_details.append(agent_task_detail)
  85. return message.content
  86. n_steps += 1
  87. raise Exception("Max run steps exceeded")
  88. # 新增方法:获取步骤记录
  89. def get_agent_task_details(self) -> list:
  90. """返回代理运行过程中的详细步骤记录"""
  91. return self.agent_task_details
  92. def get_total_input_tokens(self) -> int:
  93. """获取总输入token数"""
  94. return self.total_input_tokens
  95. def get_total_output_tokens(self) -> int:
  96. """获取总输出token数"""
  97. return self.total_output_tokens
  98. def get_total_cost(self) -> float:
  99. return OpenAICompatible.calculate_cost(self.model, self.total_input_tokens, self.total_output_tokens)