import json
from typing import List, Optional

from pqai_agent.agent import DEFAULT_MAX_RUN_STEPS
from pqai_agent.chat_service import OpenAICompatible
from pqai_agent.logging import logger
from pqai_agent.toolkit.function_tool import FunctionTool


class SimpleOpenAICompatibleChatAgent:
    """
    The simplest multi-step agent implementation.

    Each ``run`` repeatedly calls an OpenAI-compatible chat completion
    endpoint. Any tool calls the model requests are executed and their
    results fed back, until the model replies without tool calls or
    ``max_run_step`` rounds have been used.
    """

    def __init__(self, model: str, system_prompt: str,
                 tools: Optional[List[FunctionTool]] = None,
                 generate_cfg: Optional[dict] = None,
                 max_run_step: Optional[int] = None):
        """
        Args:
            model: model name; used to create the client via
                ``OpenAICompatible.create_client`` and sent with every request.
            system_prompt: content of the system message that starts each run.
            tools: tools the model may call. The list is copied, so the
                caller's list is never mutated.
            generate_cfg: extra keyword arguments forwarded verbatim to
                ``chat.completions.create``.
            max_run_step: maximum number of LLM rounds per ``run``. A falsy
                value (None or 0) falls back to ``DEFAULT_MAX_RUN_STEPS``.
        """
        self.model = model
        self.llm_client = OpenAICompatible.create_client(model)
        self.system_prompt = system_prompt
        self.tools = [*tools] if tools else []
        self.tool_map = {tool.name: tool for tool in self.tools}
        self.generate_cfg = generate_cfg or {}
        self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
        # One dict per executed tool call: {"name", "arguments", "result"}.
        self.tool_call_records = []
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        logger.debug(self.tool_map)

    def add_tool(self, tool: FunctionTool):
        """Add a tool to the agent, replacing any existing tool with the same name."""
        if tool.name in self.tool_map:
            logger.warning(f"Tool {tool.name} already exists, replacing it.")
            # Swap the old entry out in place, rather than appending, so the
            # schema list sent to the model never contains duplicate names.
            self.tools = [tool if t.name == tool.name else t for t in self.tools]
        else:
            self.tools.append(tool)
        self.tool_map[tool.name] = tool

    def run(self, user_input: str) -> str:
        """Run the agent loop for one user input and return the model's final reply.

        Args:
            user_input: content of the user message.

        Returns:
            The ``content`` of the first model message that has no tool calls.

        Raises:
            Exception: if the model calls a tool that is not registered, or if
                ``max_run_step`` rounds pass without a final reply.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_input},
        ]
        # Advertise exactly the tools we can dispatch to (one schema per name).
        tool_schemas = [tool.get_openai_tool_schema() for tool in self.tool_map.values()]
        # The OpenAI API rejects an empty ``tools`` array, so omit the
        # parameter entirely when the agent has no tools.
        tool_kwargs = {"tools": tool_schemas} if tool_schemas else {}
        logger.debug(f"start agent loop. \nmessages: {messages}")
        n_steps = 0
        while n_steps < self.max_run_step:
            response = self.llm_client.chat.completions.create(
                model=self.model, messages=messages, **tool_kwargs, **self.generate_cfg)
            message = response.choices[0].message
            self._record_usage(response)
            messages.append(message)
            logger.debug(f"current step content: {message.content}")
            if not message.tool_calls:
                return message.content
            for tool_call in message.tool_calls:
                messages.append(self._execute_tool_call(tool_call))
            n_steps += 1
        raise Exception("Max run steps exceeded")

    def _record_usage(self, response) -> None:
        """Add a response's prompt/completion token counts to the running totals."""
        usage = getattr(response, "usage", None)
        # Some OpenAI-compatible servers omit ``usage``; skip accounting
        # rather than crash the whole run.
        if usage is None:
            return
        self.total_input_tokens += usage.prompt_tokens or 0
        self.total_output_tokens += usage.completion_tokens or 0

    def _execute_tool_call(self, tool_call) -> dict:
        """Invoke the registered tool for one model tool call and return its ``tool`` message.

        Raises:
            Exception: if no tool with the requested name is registered.
        """
        function_name = tool_call.function.name
        # Some providers send an empty string for zero-argument calls.
        arguments = json.loads(tool_call.function.arguments or "{}")
        logger.debug(f"call function[{function_name}], parameter: {arguments}")
        tool = self.tool_map.get(function_name)
        if tool is None:
            logger.error(f"Function {function_name} not found in tool map.")
            raise Exception(f"Function {function_name} not found in tool map.")
        result = tool(**arguments)
        tool_message = {
            "role": "tool",
            "tool_call_id": tool_call.id,
            "content": json.dumps(result, ensure_ascii=False),
        }
        self.tool_call_records.append({
            "name": function_name,
            "arguments": arguments,
            "result": result,
        })
        return tool_message

    def get_total_input_tokens(self) -> int:
        """Return the total number of prompt (input) tokens accumulated so far."""
        return self.total_input_tokens

    def get_total_output_tokens(self) -> int:
        """Return the total number of completion (output) tokens accumulated so far."""
        return self.total_output_tokens

    def get_total_cost(self) -> float:
        """Return ``OpenAICompatible.calculate_cost`` for this model and the accumulated token totals."""
        return OpenAICompatible.calculate_cost(self.model, self.total_input_tokens, self.total_output_tokens)