Browse Source

Update message_push_agent: return multimodal messages

StrayWarrior 2 weeks ago
parent
commit
b4f59c036a
1 changed file with 11 additions and 8 deletions
  1. 11 8
      pqai_agent/agents/message_push_agent.py

+ 11 - 8
pqai_agent/agents/message_push_agent.py

@@ -113,7 +113,7 @@ QUERY_PROMPT_TEMPLATE = """现在,请通过多步思考,以客服的角色
 注意分析客服和用户当前的社交阶段,先确立本次问候的目的。
 注意一定要分析对话信息中的时间,避免和当前时间段不符的内容!注意一定要结合历史的对话情况进行分析和问候方式的选择!
 如有必要,可以使用analyse_image分析用户头像。
-使用message_notify_user发送最终的问候内容,调用时不要传入除了问候内容外的其它任何信息
+使用output_multimodal_message发送最终的消息,如果有多条消息需要发送,可以多次调用output_multimodal_message,请务必保证所有内容都通过output_multimodal_message发出
 如果无需发起问候,可直接结束,无需调用message_notify_user。
 注意每次问候只使用一种话术。
 Now, start to process your task. Please think step by step.
@@ -134,14 +134,15 @@ class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
         ])
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> str:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
         formatted_dialogue = MessagePushAgent.compose_dialogue(dialogue_history)
         query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
         self.run(query)
-        for tool_call in reversed(self.tool_call_records):
-            if tool_call['name'] == MessageNotifier.message_notify_user.__name__:
-                return tool_call['arguments']['message']
-        return ''
+        result = []
+        for tool_call in self.tool_call_records:
+            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
+                result.append(tool_call['arguments']['message'])
+        return result
 
     @staticmethod
     def compose_dialogue(dialogue: List[Dict]) -> str:
@@ -163,6 +164,8 @@ class DummyMessagePushAgent(MessagePushAgent):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> str:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
         logger.debug(f"DummyMessagePushAgent.generate_message called, context: {context}")
-        return "测试消息: {agent_name} -> {nickname}".format(**context)
+        result = [{"type": "text", "content": "测试消息: {agent_name} -> {nickname}".format(**context)},
+                  {"type": "image", "content": "https://example.com/test_image.jpg"}]
+        return result