multimodal_chat_agent.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import datetime
  2. from abc import abstractmethod
  3. from typing import Optional, List, Dict
  4. from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
  5. from pqai_agent.logging_service import logger
  6. from pqai_agent.mq_message import MessageType
  7. from pqai_agent.toolkit import get_tool
  8. from pqai_agent.toolkit.function_tool import FunctionTool
  9. from pqai_agent.toolkit.message_notifier import MessageNotifier
  class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
      """A specialized agent for message reply tasks.

      On construction the agent makes sure the ``output_multimodal_message``
      and ``message_notify_user`` tools are registered; each is fetched via
      ``get_tool`` only when its name is not already in ``tool_map``.
      Subclasses implement :meth:`generate_message`. The shared helper
      :meth:`_generate_message` renders the dialogue into a prompt, runs the
      agent, and collects the messages emitted through the
      ``output_multimodal_message`` tool.
      """

      def __init__(self, model: str, system_prompt: str,
                   tools: Optional[List[FunctionTool]] = None,
                   generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
          """Initialize the agent and register the required output tools.

          Args:
              model, system_prompt, tools, generate_cfg, max_run_step:
                  Forwarded unchanged to
                  ``SimpleOpenAICompatibleChatAgent.__init__``.
          """
          super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
          # Add only the tools that are not already registered, so a
          # same-named tool supplied by the caller is kept. Order matters:
          # each check runs after the previous tool has been added.
          for tool_name in ('output_multimodal_message', 'message_notify_user'):
              if tool_name not in self.tool_map:
                  self.add_tool(get_tool(tool_name))

      @abstractmethod
      def generate_message(self, context: Dict, dialogue_history: List[Dict],
                           query_prompt_template: str) -> List[Dict]:
          """Produce reply messages for the given context and dialogue.

          Must be implemented by subclasses; implementations may delegate to
          :meth:`_generate_message`.

          NOTE(review): ``abstractmethod`` only prevents instantiation when the
          class's metaclass is ``ABCMeta``; that depends on the parent class,
          which is not visible here.
          """
          pass

      def _generate_message(self, context: Dict, dialogue_history: List[Dict],
                            query_prompt_template: str) -> List[Dict]:
          """Run the agent on a formatted query and collect its output messages.

          Args:
              context: Values substituted into ``query_prompt_template``.
              dialogue_history: Messages rendered with :meth:`compose_dialogue`
                  and substituted for the ``{dialogue_history}`` placeholder.
                  This overrides any ``dialogue_history`` key in ``context``.
              query_prompt_template: A ``str.format`` template string.

          Returns:
              The ``message`` argument of every recorded tool call whose name
              matches ``MessageNotifier.output_multimodal_message``, in record
              order.
          """
          formatted_dialogue = MultiModalChatAgent.compose_dialogue(dialogue_history)
          # Merge into a single mapping instead of passing ``**context`` next
          # to an explicit keyword. If ``context`` already held a
          # ``dialogue_history`` key, the latter form raised "TypeError: got
          # multiple values for keyword argument". The formatted dialogue wins.
          format_args = {**context, 'dialogue_history': formatted_dialogue}
          query = query_prompt_template.format(**format_args)
          self.run(query)
          # NOTE(review): whether ``run`` resets ``tool_call_records`` is not
          # visible here. If it does not, calls from earlier runs are included.
          output_tool_name = MessageNotifier.output_multimodal_message.__name__
          return [
              tool_call['arguments']['message']
              for tool_call in self.tool_call_records
              if tool_call['name'] == output_tool_name
          ]

      @staticmethod
      def compose_dialogue(dialogue: List[Dict]) -> str:
          """Render a dialogue history as newline-separated text lines.

          Each kept message becomes
          ``[<role label>][<YYYY-mm-dd HH:MM:SS>][<type description>]<content>``.
          Messages with empty or missing ``content``, or with a ``role`` other
          than ``'user'`` / ``'assistant'``, are skipped.

          ``timestamp`` is divided by 1000 before conversion, so it is expected
          in milliseconds since the epoch. ``fromtimestamp`` is called without
          a tz, so the time is rendered in the host's local time zone.
          ``type`` falls back to ``MessageType.TEXT`` when absent or ``None``.
          Its ``description`` attribute is printed; presumably the value is a
          ``MessageType`` member (verify against callers).

          Args:
              dialogue: Message dicts read via the keys ``content``, ``role``,
                  ``timestamp`` and (optionally) ``type``.

          Returns:
              The rendered lines joined with ``'\\n'`` (empty string if nothing
              is kept).
          """
          # Role labels shown in the prompt: 'user' -> "用户" ("user"),
          # 'assistant' -> "客服" ("customer service").
          role_map = {'user': '用户', 'assistant': '客服'}
          messages = []
          for msg in dialogue:
              # .get(): skip records lacking content instead of raising KeyError.
              if not msg.get('content'):
                  continue
              role_label = role_map.get(msg.get('role'))
              if role_label is None:
                  continue
              format_dt = datetime.datetime.fromtimestamp(
                  msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
              # Explicit None check (not ``or``) so a falsy enum member, if any,
              # is not silently replaced by TEXT.
              msg_type = msg.get('type')
              if msg_type is None:
                  msg_type = MessageType.TEXT
              messages.append('[{}][{}][{}]{}'.format(
                  role_label, format_dt, msg_type.description, msg['content']))
          return '\n'.join(messages)