|
@@ -1,10 +1,8 @@
|
|
|
-import datetime
|
|
|
from typing import Optional, List, Dict
|
|
|
|
|
|
-from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
|
|
|
+from pqai_agent.agents.multimodal_chat_agent import MultiModalChatAgent
|
|
|
from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
|
|
|
from pqai_agent.logging_service import logger
|
|
|
-from pqai_agent.mq_message import MessageType
|
|
|
from pqai_agent.toolkit.function_tool import FunctionTool
|
|
|
from pqai_agent.toolkit.image_describer import ImageDescriber
|
|
|
from pqai_agent.toolkit.message_notifier import MessageNotifier
|
|
@@ -86,7 +84,7 @@ QUERY_PROMPT_TEMPLATE = """现在,请以客服的角色分析以下会话并
|
|
|
Now, start to process your task. Please think step by step.
|
|
|
"""
|
|
|
|
|
|
-class MessageReplyAgent(SimpleOpenAICompatibleChatAgent):
|
|
|
+class MessageReplyAgent(MultiModalChatAgent):
|
|
|
"""A specialized agent for message reply tasks."""
|
|
|
|
|
|
def __init__(self, model: Optional[str] = VOLCENGINE_MODEL_DEEPSEEK_V3, system_prompt: Optional[str] = None,
|
|
@@ -102,29 +100,8 @@ class MessageReplyAgent(SimpleOpenAICompatibleChatAgent):
|
|
|
|
|
|
def generate_message(self, context: Dict, dialogue_history: List[Dict],
|
|
|
query_prompt_template: Optional[str] = None) -> List[Dict]:
|
|
|
- formatted_dialogue = MessageReplyAgent.compose_dialogue(dialogue_history)
|
|
|
query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
|
|
|
- query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
|
|
|
- self.run(query)
|
|
|
- result = []
|
|
|
- for tool_call in self.tool_call_records:
|
|
|
- if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
|
|
|
- result.append(tool_call['arguments']['message'])
|
|
|
- return result
|
|
|
-
|
|
|
- @staticmethod
|
|
|
- def compose_dialogue(dialogue: List[Dict]) -> str:
|
|
|
- role_map = {'user': '用户', 'assistant': '客服'}
|
|
|
- messages = []
|
|
|
- for msg in dialogue:
|
|
|
- if not msg['content']:
|
|
|
- continue
|
|
|
- if msg['role'] not in role_map:
|
|
|
- continue
|
|
|
- format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')
|
|
|
- msg_type = msg.get('type', MessageType.TEXT).description
|
|
|
- messages.append('[{}][{}][{}]{}'.format(role_map[msg['role']], format_dt, msg_type, msg['content']))
|
|
|
- return '\n'.join(messages)
|
|
|
+ return self._generate_message(context, dialogue_history, query_prompt_template)
|
|
|
|
|
|
class DummyMessageReplyAgent(MessageReplyAgent):
|
|
|
"""A dummy agent for testing purposes."""
|
|
@@ -132,7 +109,7 @@ class DummyMessageReplyAgent(MessageReplyAgent):
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
|
- def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
|
|
|
+ def generate_message(self, context: Dict, dialogue_history: List[Dict], query_prompt_template = None) -> List[Dict]:
|
|
|
logger.debug(f"DummyMessageReplyAgent.generate_message called, context: {context}")
|
|
|
result = [{"type": "text", "content": "测试消息: {agent_name} -> {nickname}".format(**context)},
|
|
|
{"type": "image", "content": "https://example.com/test_image.jpg"}]
|