Explorar o código

Update message_push_agent: change arguments

StrayWarrior hai 1 mes
pai
achega
0830c93863
Modificáronse 1 ficheiros con 14 adicións e 3 borrados
  1. 14 3
      pqai_agent/agents/message_push_agent.py

+ 14 - 3
pqai_agent/agents/message_push_agent.py

@@ -123,9 +123,9 @@ class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
         ])
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
-    def generate_message(self, user_profile: Dict, context: Dict, dialogue_history: List[Dict]) -> str:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> str:
         formatted_dialogue = MessagePushAgent.compose_dialogue(dialogue_history)
-        query = QUERY_PROMPT_TEMPLATE.format(**user_profile, **context, dialogue_history=formatted_dialogue)
+        query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
         self.run(query)
         for tool_call in reversed(self.tool_call_records):
             if tool_call['name'] == MessageNotifier.message_notify_user.__name__:
@@ -145,6 +145,16 @@ class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
             messages.append('[{}][{}]{}'.format(role_map[msg['role']], format_dt, msg['content']))
         return '\n'.join(messages)
 
+class DummyMessagePushAgent(MessagePushAgent):
+    """A dummy agent for testing purposes."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> str:
+        return "测试消息: {agent_name} -> {nickname}".format(**context)
+
+
 if __name__ == '__main__':
     import pqai_agent.logging_service
     pqai_agent.logging_service.setup_root_logger()
@@ -161,13 +171,14 @@ if __name__ == '__main__':
     }
     test_context = {
         "current_datetime": "2025-05-10 08:00:00",
+        **test_user_profile
     }
     def create_ts(year, month, day, hour, minute):
         return datetime.datetime(year, month, day, hour, minute).timestamp() * 1000
     messages = [
         # {"role": "assistant", "content": "月哥,早上好!看到您的头像是一片宁静的户外风景,感觉您一定很喜欢大自然吧?今天天气不错,您有什么计划吗?", "timestamp": create_ts(2025, 5, 10, 8, 0)},
     ]
-    response = agent.generate_message(test_user_profile, test_context, messages)
+    response = agent.generate_message(test_context, messages)
     print(response)