multimodal_chat_agent.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. import datetime
  2. from abc import abstractmethod
  3. from typing import Optional, List, Dict
  4. from pqai_agent import configs
  5. from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
  6. from pqai_agent.logging import logger
  7. from pqai_agent.mq_message import MessageType
  8. from pqai_agent.toolkit import get_tool
  9. from pqai_agent.toolkit.function_tool import FunctionTool
  10. from pqai_agent.toolkit.message_notifier import MessageNotifier
  11. class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
  12. """A specialized agent for message reply tasks."""
  13. def __init__(self, model: str, system_prompt: str,
  14. tools: Optional[List[FunctionTool]] = None,
  15. generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
  16. super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
  17. if 'output_multimodal_message' not in self.tool_map:
  18. self.add_tool(get_tool('output_multimodal_message'))
  19. if 'message_notify_user' not in self.tool_map:
  20. self.add_tool(get_tool('message_notify_user'))
  21. @abstractmethod
  22. def generate_message(self, context: Dict, dialogue_history: List[Dict],
  23. query_prompt_template: str) -> List[Dict]:
  24. pass
  25. def _generate_message(self, context: Dict, dialogue_history: List[Dict],
  26. query_prompt_template: str) -> List[Dict]:
  27. if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
  28. return [{'type': 'text', 'content': '测试消息 -> {nickname}'.format(**context)}]
  29. formatted_dialogue = MultiModalChatAgent.compose_dialogue(dialogue_history)
  30. query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
  31. self.run(query)
  32. result = []
  33. for tool_call in self.tool_call_records:
  34. if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
  35. result.append(tool_call['arguments']['message'])
  36. return result
  37. @staticmethod
  38. def compose_dialogue(dialogue: List[Dict]) -> str:
  39. role_map = {'user': '用户', 'assistant': '客服'}
  40. messages = []
  41. for msg in dialogue:
  42. if not msg['content']:
  43. continue
  44. if msg['role'] not in role_map:
  45. continue
  46. format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
  47. msg_type = msg.get('type', MessageType.TEXT).description
  48. messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
  49. return '\n'.join(messages)