|
@@ -134,9 +134,11 @@ class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
|
|
|
])
|
|
|
super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
|
|
|
|
|
|
- def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
|
|
|
+ def generate_message(self, context: Dict, dialogue_history: List[Dict],
|
|
|
+ query_prompt_template: Optional[str] = None) -> List[Dict]:
|
|
|
formatted_dialogue = MessagePushAgent.compose_dialogue(dialogue_history)
|
|
|
- query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
|
|
|
+ query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
|
|
|
+ query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
|
|
|
self.run(query)
|
|
|
result = []
|
|
|
for tool_call in self.tool_call_records:
|
|
@@ -164,7 +166,8 @@ class DummyMessagePushAgent(MessagePushAgent):
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
- def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
|
|
|
+ def generate_message(self, context: Dict, dialogue_history: List[Dict],
|
|
|
+ query_prompt_template: Optional[str] = None) -> List[Dict]:
|
|
|
logger.debug(f"DummyMessagePushAgent.generate_message called, context: {context}")
|
|
|
result = [{"type": "text", "content": "测试消息: {agent_name} -> {nickname}".format(**context)},
|
|
|
{"type": "image", "content": "https://example.com/test_image.jpg"}]
|