浏览代码

Update message_push_agent and message_reply_agent: update tools parameter

StrayWarrior 1 月之前
父节点
当前提交
7fa34346cb
共有 2 个文件被更改,包括 14 次插入12 次删除
  1. 7 6
      pqai_agent/agents/message_push_agent.py
  2. 7 6
      pqai_agent/agents/message_reply_agent.py

+ 7 - 6
pqai_agent/agents/message_push_agent.py

@@ -127,12 +127,13 @@ class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
                  tools: Optional[List[FunctionTool]] = None,
                  generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
         system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
-        tools = tools or []
-        tools = tools.copy()
-        tools.extend([
-            *ImageDescriber().get_tools(),
-            *MessageNotifier().get_tools(),
-        ])
+        if tools is None:
+            self.tools = [
+                *ImageDescriber().get_tools(),
+                *MessageNotifier().get_tools()
+            ]
+        else:
+            self.tools = [*tools]
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
     def generate_message(self, context: Dict, dialogue_history: List[Dict],

+ 7 - 6
pqai_agent/agents/message_reply_agent.py

@@ -93,12 +93,13 @@ class MessageReplyAgent(SimpleOpenAICompatibleChatAgent):
                  tools: Optional[List[FunctionTool]] = None,
                  generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
         system_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
-        tools = tools or []
-        tools = tools.copy()
-        tools.extend([
-            *ImageDescriber().get_tools(),
-            *MessageNotifier().get_tools()
-        ])
+        if tools is None:
+            self.tools = [
+                *ImageDescriber().get_tools(),
+                *MessageNotifier().get_tools()
+            ]
+        else:
+            self.tools = [*tools]
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
     def generate_message(self, context: Dict, dialogue_history: List[Dict],