浏览代码

Update message_push_agent: change func signature

StrayWarrior 2 周之前
父节点
当前提交
164f5fd6fc
共有 1 个文件被更改,包括 6 次插入3 次删除
  1. 6 3
      pqai_agent/agents/message_push_agent.py

+ 6 - 3
pqai_agent/agents/message_push_agent.py

@@ -134,9 +134,11 @@ class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
         ])
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: Optional[str] = None) -> List[Dict]:
         formatted_dialogue = MessagePushAgent.compose_dialogue(dialogue_history)
-        query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
+        query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
+        query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
         self.run(query)
         result = []
         for tool_call in self.tool_call_records:
@@ -164,7 +166,8 @@ class DummyMessagePushAgent(MessagePushAgent):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: Optional[str] = None) -> List[Dict]:
         logger.debug(f"DummyMessagePushAgent.generate_message called, context: {context}")
         result = [{"type": "text", "content": "测试消息: {agent_name} -> {nickname}".format(**context)},
                   {"type": "image", "content": "https://example.com/test_image.jpg"}]