simple_chat_agent.py 3.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import json
  2. from typing import List, Optional
  3. from pqai_agent.agent import DEFAULT_MAX_RUN_STEPS
  4. from pqai_agent.chat_service import OpenAICompatible, VOLCENGINE_MODEL_DEEPSEEK_V3
  5. from pqai_agent.logging_service import logger
  6. from pqai_agent.toolkit.function_tool import FunctionTool
  7. from pqai_agent.toolkit.image_describer import ImageDescriber
  8. from pqai_agent.toolkit.message_notifier import MessageNotifier
  9. class SimpleOpenAICompatibleChatAgent:
  10. """ 最简单的多步Agent实现 """
  11. def __init__(self, model: str, system_prompt: str, tools: Optional[List[FunctionTool]] = None,
  12. generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
  13. self.model = model
  14. self.llm_client = OpenAICompatible.create_client(model)
  15. self.system_prompt = system_prompt
  16. self.tools = tools or []
  17. self.tool_map = {tool.name: tool for tool in self.tools}
  18. self.generate_cfg = generate_cfg or {}
  19. self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
  20. self.tool_call_records = []
  21. def run(self, user_input: str) -> str:
  22. messages = [{"role": "system", "content": self.system_prompt}]
  23. tools = [tool.get_openai_tool_schema() for tool in self.tools]
  24. messages.append({"role": "user", "content": user_input})
  25. n_steps = 0
  26. logger.debug(f"start agent loop. messages: {messages}")
  27. while n_steps < self.max_run_step:
  28. response = self.llm_client.chat.completions.create(model=self.model, messages=messages, tools=tools, **self.generate_cfg)
  29. message = response.choices[0].message
  30. messages.append(message)
  31. logger.debug(f"current step content: {message.content}")
  32. if message.tool_calls:
  33. for tool_call in message.tool_calls:
  34. function_name = tool_call.function.name
  35. arguments = json.loads(tool_call.function.arguments)
  36. logger.debug(f"call function[{function_name}], parameter: {arguments}")
  37. if function_name in self.tool_map:
  38. result = self.tool_map[function_name](**arguments)
  39. messages.append({
  40. "role": "tool",
  41. "tool_call_id": tool_call.id,
  42. "content": json.dumps(result, ensure_ascii=False)
  43. })
  44. self.tool_call_records.append({
  45. "name": function_name,
  46. "arguments": arguments,
  47. "result": result
  48. })
  49. else:
  50. logger.error(f"Function {function_name} not found in tool map.")
  51. raise Exception(f"Function {function_name} not found in tool map.")
  52. else:
  53. return message.content
  54. n_steps += 1
  55. raise Exception("Max run steps exceeded")
  56. if __name__ == '__main__':
  57. import pqai_agent.logging_service
  58. pqai_agent.logging_service.setup_root_logger()
  59. tools = [
  60. *ImageDescriber().get_tools(),
  61. *MessageNotifier().get_tools()
  62. ]
  63. system_instruction = "You are a helpful assistant."
  64. agent = SimpleOpenAICompatibleChatAgent(
  65. model=VOLCENGINE_MODEL_DEEPSEEK_V3,
  66. system_prompt=system_instruction,
  67. tools=tools
  68. )
  69. user_input = query = """
  70. 分析以下图片的内容:"http://wx.qlogo.cn/mmhead/Q3auHgzwzM5glpnBtDUianJErYf9AQsptLM3N78xP3sOR8SSibsG35HQ/0"
  71. 根据内容联想作一首诗
  72. Please think step by step.
  73. """
  74. result = agent.run(user_input)
  75. print(result)