simple_chat_agent.py 3.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. import json
  2. from typing import List, Optional
  3. from pqai_agent.agent import DEFAULT_MAX_RUN_STEPS
  4. from pqai_agent.chat_service import OpenAICompatible
  5. from pqai_agent.logging import logger
  6. from pqai_agent.toolkit.function_tool import FunctionTool
  7. class SimpleOpenAICompatibleChatAgent:
  8. """ 最简单的多步Agent实现 """
  9. def __init__(self, model: str, system_prompt: str, tools: Optional[List[FunctionTool]] = None,
  10. generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
  11. self.model = model
  12. self.llm_client = OpenAICompatible.create_client(model)
  13. self.system_prompt = system_prompt
  14. if tools:
  15. self.tools = [*tools]
  16. else:
  17. self.tools = []
  18. self.tool_map = {tool.name: tool for tool in self.tools}
  19. self.generate_cfg = generate_cfg or {}
  20. self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
  21. self.tool_call_records = []
  22. self.total_input_tokens = 0
  23. self.total_output_tokens = 0
  24. logger.debug(self.tool_map)
  25. def add_tool(self, tool: FunctionTool):
  26. """添加一个工具到Agent中"""
  27. if tool.name in self.tool_map:
  28. logger.warning(f"Tool {tool.name} already exists, replacing it.")
  29. self.tools.append(tool)
  30. self.tool_map[tool.name] = tool
  31. def run(self, user_input: str) -> str:
  32. messages = [{"role": "system", "content": self.system_prompt}]
  33. tools = [tool.get_openai_tool_schema() for tool in self.tools]
  34. messages.append({"role": "user", "content": user_input})
  35. n_steps = 0
  36. logger.debug(f"start agent loop. messages: {messages}")
  37. while n_steps < self.max_run_step:
  38. response = self.llm_client.chat.completions.create(model=self.model, messages=messages, tools=tools, **self.generate_cfg)
  39. message = response.choices[0].message
  40. self.total_input_tokens += response.usage.prompt_tokens
  41. self.total_output_tokens += response.usage.completion_tokens
  42. messages.append(message)
  43. logger.debug(f"current step content: {message.content}")
  44. if message.tool_calls:
  45. for tool_call in message.tool_calls:
  46. function_name = tool_call.function.name
  47. arguments = json.loads(tool_call.function.arguments)
  48. logger.debug(f"call function[{function_name}], parameter: {arguments}")
  49. if function_name in self.tool_map:
  50. result = self.tool_map[function_name](**arguments)
  51. messages.append({
  52. "role": "tool",
  53. "tool_call_id": tool_call.id,
  54. "content": json.dumps(result, ensure_ascii=False)
  55. })
  56. self.tool_call_records.append({
  57. "name": function_name,
  58. "arguments": arguments,
  59. "result": result
  60. })
  61. else:
  62. logger.error(f"Function {function_name} not found in tool map.")
  63. raise Exception(f"Function {function_name} not found in tool map.")
  64. else:
  65. return message.content
  66. n_steps += 1
  67. raise Exception("Max run steps exceeded")
  68. def get_total_input_tokens(self) -> int:
  69. """获取总输入token数"""
  70. return self.total_input_tokens
  71. def get_total_output_tokens(self) -> int:
  72. """获取总输出token数"""
  73. return self.total_output_tokens
  74. def get_total_cost(self) -> float:
  75. return OpenAICompatible.calculate_cost(self.model, self.total_input_tokens, self.total_output_tokens)