1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- import datetime
- from abc import abstractmethod
- from typing import Optional, List, Dict
- from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
- from pqai_agent.logging_service import logger
- from pqai_agent.mq_message import MessageType
- from pqai_agent.toolkit import get_tool
- from pqai_agent.toolkit.function_tool import FunctionTool
- from pqai_agent.toolkit.message_notifier import MessageNotifier
class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
    """A specialized agent for message reply tasks.

    On top of the base chat agent, this class guarantees that the
    ``output_multimodal_message`` and ``message_notify_user`` tools are
    registered. It also provides helpers that build a query from a prompt
    template plus dialogue history, run the agent, and collect the messages
    the model emitted via ``output_multimodal_message``.
    """

    def __init__(self, model: str, system_prompt: str,
                 tools: Optional[List[FunctionTool]] = None,
                 generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
        """Create the agent and register the reply tools if they are missing.

        Args:
            model: Model identifier forwarded to the base agent.
            system_prompt: System prompt forwarded to the base agent.
            tools: Optional extra tools forwarded to the base agent.
            generate_cfg: Optional generation config forwarded to the base agent.
            max_run_step: Optional step limit forwarded to the base agent.
        """
        super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
        # The reply pipeline relies on these two tools, so add them even when
        # the caller did not pass them in ``tools``.
        if 'output_multimodal_message' not in self.tool_map:
            self.add_tool(get_tool('output_multimodal_message'))
        if 'message_notify_user' not in self.tool_map:
            self.add_tool(get_tool('message_notify_user'))

    # NOTE(review): @abstractmethod is only enforced at instantiation time if
    # the base class uses ABCMeta (e.g. inherits from abc.ABC) -- confirm.
    @abstractmethod
    def generate_message(self, context: Dict, dialogue_history: List[Dict],
                         query_prompt_template: str) -> List[Dict]:
        """Produce reply messages for the given context; implemented by subclasses.

        Args:
            context: Values substituted into ``query_prompt_template``.
            dialogue_history: Message records to include in the prompt.
            query_prompt_template: ``str.format``-style template for the query.

        Returns:
            The reply messages to send.
        """
        pass

    def _generate_message(self, context: Dict, dialogue_history: List[Dict],
                          query_prompt_template: str) -> List[Dict]:
        """Build the query, run the agent, and collect multimodal output messages.

        The template is formatted with every key of ``context`` plus a
        ``dialogue_history`` placeholder holding the rendered dialogue. If
        ``context`` also contains a ``dialogue_history`` key, the rendered
        dialogue takes precedence.

        Args:
            context: Values substituted into ``query_prompt_template``.
            dialogue_history: Message records rendered via ``compose_dialogue``.
            query_prompt_template: ``str.format``-style template for the query.

        Returns:
            The ``message`` argument of every ``output_multimodal_message``
            tool call found in ``self.tool_call_records`` after the run.
        """
        formatted_dialogue = self.compose_dialogue(dialogue_history)
        # Merge into one mapping first: passing ``dialogue_history=`` next to
        # ``**context`` raises TypeError when context already has that key.
        format_kwargs = {**context, 'dialogue_history': formatted_dialogue}
        query = query_prompt_template.format(**format_kwargs)
        self.run(query)
        # NOTE(review): assumes the base class's run() fills
        # self.tool_call_records with dicts holding 'name' and 'arguments';
        # confirm whether records are reset between runs.
        output_tool_name = MessageNotifier.output_multimodal_message.__name__
        return [
            tool_call['arguments']['message']
            for tool_call in self.tool_call_records
            if tool_call['name'] == output_tool_name
        ]

    @staticmethod
    def compose_dialogue(dialogue: List[Dict]) -> str:
        """Render dialogue records as newline-separated text lines.

        Each kept record becomes ``[<role label>][<YYYY-mm-dd HH:MM:SS>][<type>]<content>``.
        Records with missing/empty ``content`` or a ``role`` other than
        ``'user'`` / ``'assistant'`` are skipped.

        Args:
            dialogue: Message dicts with ``role``, ``content`` and ``timestamp``
                keys, plus an optional ``type`` (defaults to ``MessageType.TEXT``).

        Returns:
            The rendered dialogue, one message per line ('' if nothing is kept).
        """
        # Display labels: '用户' = user, '客服' = customer-service agent.
        role_map = {'user': '用户', 'assistant': '客服'}
        messages = []
        for msg in dialogue:
            if not msg.get('content'):
                continue
            if msg.get('role') not in role_map:
                continue
            # The timestamp is divided by 1000, i.e. treated as epoch
            # milliseconds; fromtimestamp renders it in the local time zone.
            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
            # Presumably a MessageType member exposing ``.description``.
            msg_type = msg.get('type', MessageType.TEXT).description
            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
        return '\n'.join(messages)
|