Преглед на файлове

Refactor: add multimodal_chat_agent as the base agent for chat and push

StrayWarrior преди 4 дни
родител
ревизия
5777476ce8

+ 3 - 26
pqai_agent/agents/message_push_agent.py

@@ -1,10 +1,8 @@
-import datetime
 from typing import Optional, List, Dict
 
-from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
 from pqai_agent.logging_service import logger
-from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier
@@ -120,7 +118,7 @@ QUERY_PROMPT_TEMPLATE = """现在,请通过多步思考,以客服的角色
 Now, start to process your task. Please think step by step.
  """
 
-class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
+class MessagePushAgent(MultiModalChatAgent):
     """A specialized agent for message push tasks."""
 
     def __init__(self, model: Optional[str] = VOLCENGINE_MODEL_DEEPSEEK_V3, system_prompt: Optional[str] = None,
@@ -136,29 +134,8 @@ class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
 
     def generate_message(self, context: Dict, dialogue_history: List[Dict],
                          query_prompt_template: Optional[str] = None) -> List[Dict]:
-        formatted_dialogue = MessagePushAgent.compose_dialogue(dialogue_history)
         query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
-        query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
-        self.run(query)
-        result = []
-        for tool_call in self.tool_call_records:
-            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
-                result.append(tool_call['arguments']['message'])
-        return result
-
-    @staticmethod
-    def compose_dialogue(dialogue: List[Dict]) -> str:
-        role_map = {'user': '用户', 'assistant': '客服'}
-        messages = []
-        for msg in dialogue:
-            if not msg['content']:
-                continue
-            if msg['role'] not in role_map:
-                continue
-            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
-            msg_type = msg.get('type', MessageType.TEXT).description
-            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
-        return '\n'.join(messages)
+        return self._generate_message(context, dialogue_history, query_prompt_template)
 
 class DummyMessagePushAgent(MessagePushAgent):
     """A dummy agent for testing purposes."""

+ 4 - 27
pqai_agent/agents/message_reply_agent.py

@@ -1,10 +1,8 @@
-import datetime
 from typing import Optional, List, Dict
 
-from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
 from pqai_agent.logging_service import logger
-from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier
@@ -86,7 +84,7 @@ QUERY_PROMPT_TEMPLATE = """现在,请以客服的角色分析以下会话并
 Now, start to process your task. Please think step by step.
  """
 
-class MessageReplyAgent(SimpleOpenAICompatibleChatAgent):
+class MessageReplyAgent(MultiModalChatAgent):
     """A specialized agent for message reply tasks."""
 
     def __init__(self, model: Optional[str] = VOLCENGINE_MODEL_DEEPSEEK_V3, system_prompt: Optional[str] = None,
@@ -102,29 +100,8 @@ class MessageReplyAgent(SimpleOpenAICompatibleChatAgent):
 
     def generate_message(self, context: Dict, dialogue_history: List[Dict],
                          query_prompt_template: Optional[str] = None) -> List[Dict]:
-        formatted_dialogue = MessageReplyAgent.compose_dialogue(dialogue_history)
         query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
-        query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
-        self.run(query)
-        result = []
-        for tool_call in self.tool_call_records:
-            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
-                result.append(tool_call['arguments']['message'])
-        return result
-
-    @staticmethod
-    def compose_dialogue(dialogue: List[Dict]) -> str:
-        role_map = {'user': '用户', 'assistant': '客服'}
-        messages = []
-        for msg in dialogue:
-            if not msg['content']:
-                continue
-            if msg['role'] not in role_map:
-                continue
-            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
-            msg_type = msg.get('type', MessageType.TEXT).description
-            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
-        return '\n'.join(messages)
+        return self._generate_message(context, dialogue_history, query_prompt_template)
 
 class DummyMessageReplyAgent(MessageReplyAgent):
     """A dummy agent for testing purposes."""
@@ -132,7 +109,7 @@ class DummyMessageReplyAgent(MessageReplyAgent):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict], query_prompt_template = None) -> List[Dict]:
         logger.debug(f"DummyMessageReplyAgent.generate_message called, context: {context}")
         result = [{"type": "text", "content": "测试消息: {agent_name} -> {nickname}".format(**context)},
                   {"type": "image", "content": "https://example.com/test_image.jpg"}]

+ 52 - 0
pqai_agent/agents/multimodal_chat_agent.py

@@ -0,0 +1,52 @@
+import datetime
+from abc import abstractmethod
+from typing import Optional, List, Dict
+
+from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
+from pqai_agent.logging_service import logger
+from pqai_agent.mq_message import MessageType
+from pqai_agent.toolkit.function_tool import FunctionTool
+from pqai_agent.toolkit.message_notifier import MessageNotifier
+
+
+class MultiModalChatAgent(SimpleOpenAICompatibleChatAgent):
+    """Base chat agent whose replies are delivered through the multimodal message tools.
+
+    Subclasses implement ``generate_message``; the shared workflow (formatting the
+    dialogue, running the agent, collecting the emitted messages) lives in
+    ``_generate_message``.
+    """
+
+    def __init__(self, model: str, system_prompt: str,
+                 tools: Optional[List[FunctionTool]] = None,
+                 generate_cfg: Optional[dict] = None, max_run_step: Optional[int] = None):
+        """Initialise the underlying chat agent and make sure both messaging tools are registered."""
+        super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
+        # Only add the MessageNotifier tools when the caller has not already supplied them.
+        if 'output_multimodal_message' not in self.tool_map:
+            self.add_tool(FunctionTool(MessageNotifier.output_multimodal_message))
+        if 'message_notify_user' not in self.tool_map:
+            self.add_tool(FunctionTool(MessageNotifier.message_notify_user))
+
+    @abstractmethod
+    def generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: Optional[str] = None) -> List[Dict]:
+        """Produce the outgoing messages for a conversation; subclasses must override this.
+
+        Args:
+            context: values substituted into the query prompt template.
+            dialogue_history: past messages of the conversation.
+            query_prompt_template: optional template overriding the subclass default.
+
+        Raises:
+            NotImplementedError: always, when the base implementation is called.
+        """
+        # The class hierarchy does not use ABCMeta, so @abstractmethod alone is not
+        # enforced at instantiation time; fail loudly instead of silently returning None.
+        raise NotImplementedError(
+            f"{type(self).__name__} must implement generate_message()")
+
+    def _generate_message(self, context: Dict, dialogue_history: List[Dict],
+                          query_prompt_template: str) -> List[Dict]:
+        """Run the agent on the formatted prompt and collect the messages it emitted.
+
+        The template is formatted with ``context`` plus a ``dialogue_history`` field
+        holding the text produced by ``compose_dialogue``. The result contains the
+        ``message`` argument of every recorded ``output_multimodal_message`` tool call,
+        in the order they appear in ``self.tool_call_records``.
+        """
+        formatted_dialogue = self.compose_dialogue(dialogue_history)
+        query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
+        self.run(query)
+        output_tool_name = MessageNotifier.output_multimodal_message.__name__
+        return [tool_call['arguments']['message']
+                for tool_call in self.tool_call_records
+                if tool_call['name'] == output_tool_name]
+
+    @staticmethod
+    def compose_dialogue(dialogue: List[Dict]) -> str:
+        """Render the dialogue as one ``[role][time][type]content`` line per message.
+
+        Messages with empty content or with a role other than ``user`` / ``assistant``
+        are omitted. Lines are joined with newlines.
+        """
+        role_map = {'user': '用户', 'assistant': '客服'}
+        messages = []
+        for msg in dialogue:
+            # Skip empty messages and roles that have no display label.
+            if not msg['content'] or msg['role'] not in role_map:
+                continue
+            # The timestamp is divided by 1000 before conversion (i.e. treated as
+            # milliseconds) and rendered in the process's local timezone.
+            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
+            # A message without an explicit type is treated as text.
+            msg_type = msg.get('type', MessageType.TEXT).description
+            messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
+        return '\n'.join(messages)

+ 7 - 0
pqai_agent/agents/simple_chat_agent.py

@@ -24,6 +24,13 @@ class SimpleOpenAICompatibleChatAgent:
         self.max_run_step = max_run_step or DEFAULT_MAX_RUN_STEPS
         self.tool_call_records = []
 
+    def add_tool(self, tool: FunctionTool):
+        """Add a tool to the agent, replacing any registered tool that has the same name."""
+        if tool.name in self.tool_map:
+            logger.warning(f"Tool {tool.name} already exists, replacing it.")
+            # Remove the old entry in place so self.tools never holds two tools with
+            # the same name (run() builds its schema list from self.tools).
+            self.tools[:] = [t for t in self.tools if t.name != tool.name]
+        self.tools.append(tool)
+        self.tool_map[tool.name] = tool
+
     def run(self, user_input: str) -> str:
         messages = [{"role": "system", "content": self.system_prompt}]
         tools = [tool.get_openai_tool_schema() for tool in self.tools]