12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- import json
- from typing import List, Optional
- from pqai_agent.agent import DEFAULT_MAX_RUN_STEPS
- from pqai_agent.chat_service import OpenAICompatible
- from pqai_agent.logging_service import logger
- from pqai_agent.toolkit.function_tool import FunctionTool
class SimpleOpenAICompatibleChatAgent:
    """The simplest multi-step agent implementation.

    Runs a loop against an OpenAI-compatible chat-completions client. Each step:

    1. sends the conversation plus the tool schemas;
    2. executes every tool call the model requests and appends the results;
    3. repeats until the model replies without requesting a tool, or the step
       budget is exhausted.
    """

    def __init__(self, model: str, system_prompt: str, tools: Optional[List[FunctionTool]] = None,
                 generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
        """Create the agent.

        Args:
            model: Model name. Used to create the client and sent with every request.
            system_prompt: Content of the system message that opens each ``run``.
            tools: Tools the model may call. If several share a name, the last
                one wins and takes the position of the first.
            generate_cfg: Extra keyword arguments forwarded verbatim to
                ``chat.completions.create`` on every step.
            max_run_step: Maximum number of LLM calls per ``run``. Any falsy
                value (``None`` or ``0``) falls back to ``DEFAULT_MAX_RUN_STEPS``.
        """
        self.model = model
        self.llm_client = OpenAICompatible.create_client(model)
        self.system_prompt = system_prompt
        # Name -> tool; with duplicate names the later tool overrides the earlier one.
        self.tool_map = {tool.name: tool for tool in (tools or [])}
        # Derive the list from the map so it never holds two tools with the same
        # name, which would send duplicate function schemas to the model.
        self.tools = list(self.tool_map.values())
        self.generate_cfg = generate_cfg or {}
        self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
        # Every executed tool call ({"name", "arguments", "result"}), accumulated across runs.
        self.tool_call_records = []
        logger.debug(self.tool_map)

    def add_tool(self, tool: FunctionTool):
        """Add a tool to the agent, replacing any existing tool with the same name.

        A replaced tool keeps its original position in ``self.tools``.
        """
        if tool.name in self.tool_map:
            logger.warning(f"Tool {tool.name} already exists, replacing it.")
            # Swap in place rather than append, so the old schema is not sent as well.
            for index, existing in enumerate(self.tools):
                if existing.name == tool.name:
                    self.tools[index] = tool
                    break
        else:
            self.tools.append(tool)
        self.tool_map[tool.name] = tool

    def run(self, user_input: str) -> str:
        """Run the agent loop for one user message.

        Args:
            user_input: Content of the user message.

        Returns:
            The content of the first model message that requests no tool calls.

        Raises:
            Exception: If the model calls a tool that is not registered, or if
                ``self.max_run_step`` steps pass without a final answer.
            json.JSONDecodeError: If the model sends malformed tool arguments.
            Any exception raised by a tool or by the LLM client propagates unchanged.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_input},
        ]
        tool_schemas = [tool.get_openai_tool_schema() for tool in self.tools]
        # The OpenAI API rejects an empty `tools` array, so only send it when
        # there is at least one tool.
        tools_kwargs = {"tools": tool_schemas} if tool_schemas else {}
        logger.debug(f"start agent loop. messages: {messages}")

        for _ in range(self.max_run_step):
            response = self.llm_client.chat.completions.create(
                model=self.model, messages=messages, **tools_kwargs, **self.generate_cfg)
            message = response.choices[0].message
            # The assistant turn must precede the tool results that answer it.
            messages.append(message)
            logger.debug(f"current step content: {message.content}")

            if not message.tool_calls:
                return message.content

            for tool_call in message.tool_calls:
                function_name = tool_call.function.name
                raw_arguments = tool_call.function.arguments
                # Some OpenAI-compatible backends send "" (or None) for a call
                # without parameters, which json.loads would reject.
                arguments = json.loads(raw_arguments) if raw_arguments and raw_arguments.strip() else {}
                logger.debug(f"call function[{function_name}], parameter: {arguments}")

                if function_name not in self.tool_map:
                    logger.error(f"Function {function_name} not found in tool map.")
                    raise Exception(f"Function {function_name} not found in tool map.")

                result = self.tool_map[function_name](**arguments)
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": json.dumps(result, ensure_ascii=False)
                })
                self.tool_call_records.append({
                    "name": function_name,
                    "arguments": arguments,
                    "result": result
                })

        raise Exception("Max run steps exceeded")
|