multimodal_chat_agent.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import datetime
  2. from abc import abstractmethod
  3. from typing import Optional, List, Dict
  4. from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
  5. from pqai_agent.logging_service import logger
  6. from pqai_agent.mq_message import MessageType
  7. from pqai_agent.toolkit.function_tool import FunctionTool
  8. from pqai_agent.toolkit.message_notifier import MessageNotifier
  9. class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
  10. """A specialized agent for message reply tasks."""
  11. def __init__(self, model: str, system_prompt: str,
  12. tools: Optional[List[FunctionTool]] = None,
  13. generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
  14. super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
  15. if 'output_multimodal_message' not in self.tool_map:
  16. self.add_tool(FunctionTool(MessageNotifier.output_multimodal_message))
  17. if 'message_notify_user' not in self.tool_map:
  18. self.add_tool(FunctionTool(MessageNotifier.message_notify_user))
  19. @abstractmethod
  20. def generate_message(self, context: Dict, dialogue_history: List[Dict],
  21. query_prompt_template: str) -> List[Dict]:
  22. pass
  23. def _generate_message(self, context: Dict, dialogue_history: List[Dict],
  24. query_prompt_template: str) -> List[Dict]:
  25. formatted_dialogue = MultiModalChatAgent.compose_dialogue(dialogue_history)
  26. query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
  27. self.run(query)
  28. result = []
  29. for tool_call in self.tool_call_records:
  30. if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
  31. result.append(tool_call['arguments']['message'])
  32. return result
  33. @staticmethod
  34. def compose_dialogue(dialogue: List[Dict]) -> str:
  35. role_map = {'user': '用户', 'assistant': '客服'}
  36. messages = []
  37. for msg in dialogue:
  38. if not msg['content']:
  39. continue
  40. if msg['role'] not in role_map:
  41. continue
  42. format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
  43. msg_type = msg.get('type', MessageType.TEXT).description
  44. messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
  45. return '\n'.join(messages)