import datetime
from abc import abstractmethod
from typing import Optional, List, Dict

from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
from pqai_agent.logging_service import logger
from pqai_agent.mq_message import MessageType
from pqai_agent.toolkit import get_tool
from pqai_agent.toolkit.function_tool import FunctionTool
from pqai_agent.toolkit.message_notifier import MessageNotifier


class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
    """A specialized agent for message reply tasks.

    On construction the agent makes sure the ``output_multimodal_message``
    and ``message_notify_user`` tools are registered. Replies are collected
    from the recorded calls to the ``output_multimodal_message`` tool rather
    than from the return value of ``run``.
    """

    def __init__(self, model: str, system_prompt: str,
                 tools: Optional[List[FunctionTool]] = None,
                 generate_cfg: Optional[dict] = None,
                 max_run_step: Optional[int] = None):
        """Initialise the agent and register the messaging tools if absent.

        All arguments are forwarded unchanged, positionally, to the parent
        constructor.
        """
        super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
        # Only add the tools the caller did not already supply.
        for tool_name in ('output_multimodal_message', 'message_notify_user'):
            if tool_name not in self.tool_map:
                self.add_tool(get_tool(tool_name))

    @abstractmethod
    def generate_message(self, context: Dict, dialogue_history: List[Dict],
                         query_prompt_template: str) -> List[Dict]:
        """Produce reply messages; the body here is a placeholder.

        Takes the same arguments as ``_generate_message``.
        """

    def _generate_message(self, context: Dict, dialogue_history: List[Dict],
                          query_prompt_template: str) -> List[Dict]:
        """Run the agent on a rendered prompt and collect its output messages.

        Args:
            context: Values substituted into ``query_prompt_template``.
            dialogue_history: Messages rendered with ``compose_dialogue`` and
                exposed to the template as ``{dialogue_history}``.
            query_prompt_template: A ``str.format`` style template.

        Returns:
            The ``message`` argument of every recorded tool call whose name
            matches ``MessageNotifier.output_multimodal_message``, in
            recorded order.
        """
        formatted_dialogue = MultiModalChatAgent.compose_dialogue(dialogue_history)
        # Merge into a single mapping so that a ``dialogue_history`` key that
        # is already present in ``context`` is overridden by the rendered
        # dialogue instead of raising "got multiple values for keyword
        # argument".
        template_values = {**context, 'dialogue_history': formatted_dialogue}
        query = query_prompt_template.format(**template_values)
        self.run(query)
        # NOTE(review): this reads every entry in ``self.tool_call_records``;
        # whether ``run`` resets that list between calls is not visible
        # here -- confirm against the base class.
        target_tool = MessageNotifier.output_multimodal_message.__name__
        return [
            record['arguments']['message']
            for record in self.tool_call_records
            if record['name'] == target_tool
        ]

    @staticmethod
    def compose_dialogue(dialogue: List[Dict]) -> str:
        """Render a dialogue as one ``[role][time][type]content`` line per message.

        Messages with falsy ``content`` or a ``role`` other than ``user`` /
        ``assistant`` are skipped. ``timestamp`` is divided by 1000 before
        conversion (presumably epoch milliseconds -- confirm with callers)
        and formatted in the local timezone. A missing ``type`` falls back to
        ``MessageType.TEXT``; the type's ``description`` attribute is used as
        the label.

        Args:
            dialogue: Message dicts with ``content``, ``role``, ``timestamp``
                and optionally ``type``.

        Returns:
            The rendered lines joined with newlines; an empty string when no
            message qualifies.
        """
        # Display labels emitted into the prompt: "user" / "customer service".
        role_map = {'user': '用户', 'assistant': '客服'}
        messages = []
        for msg in dialogue:
            if not msg['content']:
                continue
            if msg['role'] not in role_map:
                continue
            format_dt = datetime.datetime.fromtimestamp(
                msg['timestamp'] / 1000
            ).strftime('%Y-%m-%d %H:%M:%S')
            msg_type = msg.get('type', MessageType.TEXT).description
            messages.append(
                '[{}][{}][{}]{}'.format(
                    role_map[msg['role']], format_dt, msg_type, msg['content']
                )
            )
        return '\n'.join(messages)