|
@@ -101,9 +101,11 @@ class MessageReplyAgent(SimpleOpenAICompatibleChatAgent):
|
|
|
])
|
|
|
super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
|
|
|
|
|
|
- def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
|
|
|
+ def generate_message(self, context: Dict, dialogue_history: List[Dict],
|
|
|
+ query_prompt_template: Optional[str] = None) -> List[Dict]:
|
|
|
formatted_dialogue = MessageReplyAgent.compose_dialogue(dialogue_history)
|
|
|
- query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
|
|
|
+ query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
|
|
|
+ query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
|
|
|
self.run(query)
|
|
|
result = []
|
|
|
for tool_call in self.tool_call_records:
|