import json
from typing import List, Optional

from pqai_agent.agent import DEFAULT_MAX_RUN_STEPS
from pqai_agent.chat_service import OpenAICompatible
from pqai_agent.logging_service import logger
from pqai_agent.toolkit.function_tool import FunctionTool


class SimpleOpenAICompatibleChatAgent:
    """Simplest multi-step agent implementation.

    Each step sends the whole conversation to an OpenAI-compatible chat
    completions endpoint. If the model requests tool calls, they are executed
    locally and their results are appended to the conversation before the next
    step. Otherwise the model's reply content is returned.
    """

    def __init__(self, model: str, system_prompt: str, tools: Optional[List[FunctionTool]] = None,
                 generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
        """
        Args:
            model: Model name, used both to create the client and in every request.
            system_prompt: Content of the system message that opens every run.
            tools: Tools the model may call; dispatched by ``tool.name``.
            generate_cfg: Extra keyword arguments forwarded verbatim to
                ``chat.completions.create``.
            max_run_step: Maximum number of model calls per ``run``. Falsy
                values (None, 0) fall back to ``DEFAULT_MAX_RUN_STEPS``.
        """
        self.model = model
        self.llm_client = OpenAICompatible.create_client(model)
        self.system_prompt = system_prompt
        self.tools = tools or []
        self.tool_map = {tool.name: tool for tool in self.tools}
        self.generate_cfg = generate_cfg or {}
        self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
        # Every successful tool call made by this instance, across all runs
        # (this list is never cleared).
        self.tool_call_records = []

    def run(self, user_input: str) -> str:
        """Run the agent loop for one user input.

        Args:
            user_input: Content of the user message.

        Returns:
            The content of the first model reply that contains no tool calls.

        Raises:
            Exception: If the model requests a tool that is not registered, or
                if ``max_run_step`` model calls complete without a final reply.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_input},
        ]
        tool_schemas = [tool.get_openai_tool_schema() for tool in self.tools]
        # The OpenAI API rejects an empty `tools` array, so the parameter is
        # sent only when at least one tool is registered.
        tool_kwargs = {"tools": tool_schemas} if tool_schemas else {}
        logger.debug(f"start agent loop. messages: {messages}")

        for _ in range(self.max_run_step):
            response = self.llm_client.chat.completions.create(
                model=self.model, messages=messages, **tool_kwargs, **self.generate_cfg)
            message = response.choices[0].message
            # The SDK message object is appended as-is so that its tool_calls are
            # echoed back on the next request. NOTE(review): this assumes the
            # client serializes SDK message objects; confirm for non-OpenAI backends.
            messages.append(message)
            logger.debug(f"current step content: {message.content}")

            if not message.tool_calls:
                return message.content
            for tool_call in message.tool_calls:
                messages.append(self._call_tool(tool_call))

        raise Exception("Max run steps exceeded")

    def _call_tool(self, tool_call) -> dict:
        """Execute one model-requested tool call and return the ``tool`` message for it.

        On success, the call is also appended to ``self.tool_call_records``.

        Raises:
            Exception: If the requested function is not in ``self.tool_map``.
        """
        function_name = tool_call.function.name
        arguments = self._parse_arguments(tool_call.function.arguments)
        logger.debug(f"call function[{function_name}], parameter: {arguments}")

        tool = self.tool_map.get(function_name)
        if tool is None:
            logger.error(f"Function {function_name} not found in tool map.")
            raise Exception(f"Function {function_name} not found in tool map.")

        result = tool(**arguments)
        tool_message = {
            "role": "tool",
            "tool_call_id": tool_call.id,
            "content": json.dumps(result, ensure_ascii=False),
        }
        self.tool_call_records.append({
            "name": function_name,
            "arguments": arguments,
            "result": result,
        })
        return tool_message

    @staticmethod
    def _parse_arguments(raw_arguments: Optional[str]) -> dict:
        """Decode a tool call's JSON argument string into keyword arguments.

        Some OpenAI-compatible backends send an empty string, or JSON ``null``,
        for tools without parameters. Both are treated as "no arguments"
        instead of crashing ``json.loads`` or the ``**`` unpacking.
        """
        if not raw_arguments or not raw_arguments.strip():
            return {}
        arguments = json.loads(raw_arguments)
        return {} if arguments is None else arguments