Ver código fonte

Merge branch 'feature/202505-multimodal' of Server/AgentCoreService into master

fengzhoutian 1 semana atrás
pai
commit
3a3b153f2f

+ 72 - 88
pqai_agent/agent_service.py

@@ -6,7 +6,7 @@ import re
 import signal
 import sys
 import time
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 import logging
 from datetime import datetime, timedelta
 import threading
@@ -18,6 +18,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
 from sqlalchemy.orm import sessionmaker
 
 from pqai_agent import configs
+from pqai_agent.agents.message_reply_agent import MessageReplyAgent
 from pqai_agent.configs import apollo_config
 from pqai_agent.exceptions import NoRetryException
 from pqai_agent.logging_service import logger
@@ -31,7 +32,7 @@ from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.user_manager import UserManager, UserRelationManager
 from pqai_agent.message_queue_backend import MessageQueueBackend, AliyunRocketMQQueueBackend
 from pqai_agent.user_profile_extractor import UserProfileExtractor
-from pqai_agent.message import MessageType, Message, MessageChannel
+from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.utils.db_utils import create_sql_engine
 
 
@@ -136,7 +137,7 @@ class AgentService:
             time.sleep(1)
         logger.info("Scheduler event processing thread exit")
 
-    def process_scheduler_event(self, msg: Message):
+    def process_scheduler_event(self, msg: MqMessage):
         if msg.type == MessageType.AGGREGATION_TRIGGER:
             # 延迟触发的消息,需放入接收队列以驱动Agent运转
             self.receive_queue.produce(msg)
@@ -148,7 +149,7 @@ class AgentService:
         agent_key = 'agent_{}_{}'.format(staff_id, user_id)
         if agent_key not in self.agent_registry:
             self.agent_registry[agent_key] = DialogueManager(
-                staff_id, user_id, self.user_manager, self.agent_state_cache)
+                staff_id, user_id, self.user_manager, self.agent_state_cache, self.AgentDBSession)
         agent = self.agent_registry[agent_key]
         agent.refresh_profile()
         return agent
@@ -190,7 +191,7 @@ class AgentService:
                     logger.error("Error processing message: {}, {}".format(e, error_stack))
             time.sleep(0.1)
         receive_queue.shutdown()
-        logger.info("Message processing thread exit")
+        logger.info("MqMessage processing thread exit")
 
     def start(self, blocking=False):
         self.running = True
@@ -255,7 +256,7 @@ class AgentService:
     def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
         logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
         message_ts = int((time.time() + delay_sec) * 1000)
-        msg = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
+        msg = MqMessage.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
         # 系统消息使用特定的msgId,无实际意义
         msg.msgId = -MessageType.AGGREGATION_TRIGGER.value
         if self.scheduler_mode == 'mq':
@@ -265,7 +266,7 @@ class AgentService:
                                    'date',
                                    run_date=datetime.now() + timedelta(seconds=delay_sec))
 
-    def process_single_message(self, message: Message):
+    def process_single_message(self, message: MqMessage):
         user_id = message.sender
         staff_id = message.receiver
 
@@ -293,16 +294,8 @@ class AgentService:
             elif need_response:
                 # 先更新用户画像再处理回复
                 self._update_user_profile(user_id, user_profile, agent.dialogue_history[-10:])
-                resp = self._get_chat_response(user_id, agent, message_text)
-                if resp:
-                    recent_dialogue = agent.dialogue_history[-10:]
-                    agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
-                    if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
-                        message_type = MessageType.TEXT
-                    else:
-                        message_type = self.response_type_detector.detect_type(
-                            recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
-                    self.send_response(staff_id, user_id, resp, message_type)
+                resp = self.get_chat_response(agent, message_text)
+                self.send_responses(agent, resp)
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
             # 当前消息处理成功,commit并持久化agent状态
@@ -311,27 +304,53 @@ class AgentService:
             agent.rollback_state()
             raise e
 
-    def send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
-        logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
-        current_ts = int(time.time() * 1000)
+    def send_responses(self, agent: DialogueManager, contents: List[Dict]):
+        """Send a batch of generated reply items to the user served by ``agent``.
+
+        Each item is a dict with ``"type"`` and ``"content"`` keys. When the staff
+        id is listed in the ``agent_voice_whitelist`` config, TEXT items are first
+        passed to the response type detector and their ``"type"`` is overwritten
+        in place with its result. Every item is then handed to
+        ``send_multimodal_response`` and the agent's last active interaction time
+        is set to the current time in seconds.
+
+        :param agent: dialogue manager owning the conversation
+        :param contents: reply items to send; an empty list is only logged
+        """
+        staff_id, user_id = agent.staff_id, agent.user_id
+        history_window = agent.dialogue_history[-10:]
+        voice_staffs = set(apollo_config.get_json_value("agent_voice_whitelist", []))
+        # First pass: decide the final message type of every TEXT item.
+        for item in contents:
+            if item["type"] == MessageType.TEXT and staff_id in voice_staffs:
+                item["type"] = self.response_type_detector.detect_type(
+                    history_window, item["content"], enable_random=True)
+        if not contents:
+            logger.debug(f"staff[{staff_id}], user[{user_id}]: no messages to send")
+            return
+        # Second pass: enqueue everything. The timestamp is taken before sending.
+        # NOTE(review): skip_check=True bypasses can_send_to_user for replies - confirm this is intended.
+        sent_at_sec = int(time.time())
+        for item in contents:
+            self.send_multimodal_response(staff_id, user_id, item, skip_check=True)
+        agent.update_last_active_interaction_time(sent_at_sec)
+
+    def can_send_to_user(self, staff_id, user_id) -> bool:
+        """Check the response whitelists for a staff/user pair.
+
+        Returns True when ``staff_id`` is in the ``agent_response_whitelist_staffs``
+        config or the user has at least one tag from ``agent_response_whitelist_tags``;
+        otherwise logs a warning and returns False.
+        """
         user_tags = self.user_relation_manager.get_user_tags(user_id)
         white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags", []))
         hit_white_list_tags = len(set(user_tags).intersection(white_list_tags)) > 0
-        # FIXME(zhoutian)
-        # 测试期间临时逻辑,只发送特定的账号或特定用户
         staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs", []))
-        if not (staff_id in staff_white_lists or hit_white_list_tags or skip_check):
+        if not (staff_id in staff_white_lists or hit_white_list_tags):
             logger.warning(f"staff[{staff_id}] user[{user_id}]: skip reply")
+            return False
+        return True
+
+    def send_multimodal_response(self, staff_id, user_id, response: Dict, skip_check=False):
+        """Enqueue one response item for delivery to the user.
+
+        :param staff_id: sender of the outgoing message
+        :param user_id: receiver of the outgoing message
+        :param response: dict with ``"type"`` (a MessageType) and ``"content"`` keys
+        :param skip_check: when True, the ``can_send_to_user`` whitelist check is skipped
+
+        Items whose type is not TEXT, IMAGE_QW or VOICE are logged as errors and dropped.
+        """
+        message_type = response["type"]
+        logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
+        if message_type not in (MessageType.TEXT, MessageType.IMAGE_QW, MessageType.VOICE):
+            logger.error(f"staff[{staff_id}] user[{user_id}]: unsupported message type {message_type}")
+            return
+        if not skip_check and not self.can_send_to_user(staff_id, user_id):
             return
+        current_ts = int(time.time() * 1000)
+        # NOTE(review): the rate limiter receives the whole response dict, while the queue
+        # message below carries only response["content"] - confirm wait_for_sending expects a dict.
         self.send_rate_limiter.wait_for_sending(staff_id, response)
         self.send_queue.produce(
-            Message.build(message_type, MessageChannel.CORP_WECHAT,
-                          staff_id, user_id, response, current_ts)
+            MqMessage.build(message_type, MessageChannel.CORP_WECHAT,
+                            staff_id, user_id, response["content"], current_ts)
         )
 
-    def _route_to_human_intervention(self, user_id: str, origin_message: Message):
+    def _route_to_human_intervention(self, user_id: str, origin_message: MqMessage):
         """路由到人工干预"""
-        self.human_queue.produce(Message.build(
+        self.human_queue.produce(MqMessage.build(
             MessageType.TEXT,
             origin_message.channel,
             origin_message.sender,
@@ -392,68 +411,15 @@ class AgentService:
         # 问题在于,如果每次创建出新的PushTaskWorkerPool,在上次任务有未处理完的消息即退出时,会有未处理的消息堆积
         push_task_worker_pool.wait_to_finish()
 
-    def _check_initiative_conversations_v1(self):
-        logger.info("start to check initiative conversations")
-        if not DialogueManager.is_time_suitable_for_active_conversation():
-            logger.info("time is not suitable for active conversation")
-            return
-        white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags', []))
-        first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
-        # 合并白名单,减少配置成本
-        white_list_tags.update(first_initiate_tags)
-        voice_tags = set(apollo_config.get_json_value('agent_initiate_by_voice_tags', []))
-
-
-        """定时检查主动发起对话"""
-        for staff_user in self.user_relation_manager.list_staff_users():
-            staff_id = staff_user['staff_id']
-            user_id = staff_user['user_id']
-            agent = self.get_agent_instance(staff_id, user_id)
-            should_initiate = agent.should_initiate_conversation()
-            user_tags = self.user_relation_manager.get_user_tags(user_id)
-
-            if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
-                should_initiate = False
-
-            if should_initiate:
-                logger.warning(f"user[{user_id}], tags{user_tags}: initiate conversation")
-                # FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突 需要并入事件驱动框架
-                agent.do_state_change(DialogueState.GREETING)
-                try:
-                    if agent.previous_state == DialogueState.INITIALIZED or first_initiate_tags.intersection(user_tags):
-                        # 完全无交互历史的用户才使用此策略,但新用户接入即会产生“我已添加了你”的消息将Agent初始化
-                        # 因此存量用户无法使用该状态做实验
-                        # TODO:增加基于对话历史的判断、策略去重;如果对话间隔过长需要使用长期记忆检索;在无长期记忆时,可采用用户添加时间来判断
-                        resp = self._generate_active_greeting_message(agent, user_tags)
-                    else:
-                        resp = self._get_chat_response(user_id, agent, None)
-                    if resp:
-                        if set(user_tags).intersection(voice_tags):
-                            message_type = MessageType.VOICE
-                        else:
-                            message_type = MessageType.TEXT
-                        self.send_response(staff_id, user_id, resp, message_type, skip_check=True)
-                    agent.persist_state()
-                except Exception as e:
-                    # FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突
-                    agent.rollback_state()
-                    logger.error("Error in active greeting: {}".format(e))
-            else:
-                logger.debug(f"user[{user_id}], do not initiate conversation")
-
-    def _generate_active_greeting_message(self, agent: DialogueManager, user_tags: List[str]=None):
-        chat_config = agent.build_active_greeting_config(user_tags)
-        chat_response = self._call_chat_api(chat_config, ChatServiceType.OPENAI_COMPATIBLE)
-        chat_response = self.sanitize_response(chat_response)
-        if response := agent.generate_response(chat_response):
-            return response
+    def get_chat_response(self, agent: DialogueManager, user_message: Optional[str]) -> List[Dict]:
+        """Generate reply items for ``agent`` using the configured chat agent version.
+
+        The version is read from ``system.chat_agent_version`` in ``self.config`` (default 1).
+        Version 2 delegates to ``_get_chat_response_v2`` and does not use ``user_message``.
+        Any other value uses ``_get_chat_response_v1`` and wraps its text result in a
+        single TEXT item.
+
+        :return: list of ``{"type", "content"}`` dicts; empty when nothing was generated
+        """
+        chat_agent_ver = self.config.get('system', {}).get('chat_agent_version', 1)
+        if chat_agent_ver == 2:
+            return self._get_chat_response_v2(agent)
         else:
-            logger.warning(f"staff[{agent.staff_id}] user[{agent.user_id}]: no response generated")
-            return None
+            text_resp = self._get_chat_response_v1(agent, user_message)
+            return [{"type": MessageType.TEXT, "content": text_resp}] if text_resp else []
 
-    def _get_chat_response(self, user_id: str, agent: DialogueManager,
-                           user_message: Optional[str]):
-        """处理LLM响应"""
+    def _get_chat_response_v1(self, agent: DialogueManager, user_message: Optional[str]) -> Optional[str]:
         chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
         config_for_logging = chat_config.copy()
         config_for_logging['messages'] = config_for_logging['messages'][-20:]
@@ -464,9 +430,27 @@ class AgentService:
         if response := agent.generate_response(chat_response):
             return response
         else:
-            logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
+            logger.warning(f"staff[{agent.staff_id}] user[{agent.user_id}]: no response generated")
             return None
 
+    def _get_chat_response_v2(self, main_agent: DialogueManager) -> List[Dict]:
+        """Generate reply items with ``MessageReplyAgent`` (chat agent version 2).
+
+        The reply agent is run on the prompt context and the last 100 dialogue
+        history entries of ``main_agent``. Each produced item is passed through
+        ``main_agent.generate_multimodal_response``, which records it and may
+        change the dialogue state.
+
+        :param main_agent: dialogue manager owning the conversation
+        :return: items to send; empty when the agent produced nothing or when an
+                 item was rejected (invalid item or conversation-ending marker)
+        """
+        chat_agent = MessageReplyAgent()
+        chat_responses = chat_agent.generate_message(
+            context=main_agent.get_prompt_context(None),
+            dialogue_history=main_agent.dialogue_history[-100:]
+        )
+        if not chat_responses:
+            logger.warning(f"staff[{main_agent.staff_id}] user[{main_agent.user_id}]: no response generated")
+            return []
+        final_responses = []
+        for chat_response in chat_responses:
+            if response := main_agent.generate_multimodal_response(chat_response):
+                final_responses.append(response)
+            else:
+                # An invalid or terminating item was found: drop everything queued so far
+                # and stop, so that items after the marker are neither recorded nor sent.
+                # NOTE(review): items accepted before this point remain in the dialogue history.
+                final_responses.clear()
+                break
+        return final_responses
+
     def _call_chat_api(self, chat_config: Dict, chat_service_type: ChatServiceType) -> str:
         if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
             return 'LLM模拟回复 {}'.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

+ 16 - 10
pqai_agent/agents/message_push_agent.py

@@ -4,7 +4,7 @@ from typing import Optional, List, Dict
 from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
 from pqai_agent.logging_service import logger
-from pqai_agent.message import MessageType
+from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier
@@ -113,7 +113,7 @@ QUERY_PROMPT_TEMPLATE = """现在,请通过多步思考,以客服的角色
 注意分析客服和用户当前的社交阶段,先确立本次问候的目的。
 注意一定要分析对话信息中的时间,避免和当前时间段不符的内容!注意一定要结合历史的对话情况进行分析和问候方式的选择!
 如有必要,可以使用analyse_image分析用户头像。
-使用message_notify_user发送最终的问候内容,调用时不要传入除了问候内容外的其它任何信息
+使用output_multimodal_message发送最终的消息,如果有多条消息需要发送,可以多次调用output_multimodal_message,请务必保证所有内容都通过output_multimodal_message发出
 如果无需发起问候,可直接结束,无需调用message_notify_user。
 注意每次问候只使用一种话术。
 Now, start to process your task. Please think step by step.
@@ -134,14 +134,17 @@ class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
         ])
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> str:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: Optional[str] = None) -> List[Dict]:
+        """Run the push agent and collect the messages it emitted.
+
+        :param context: values substituted into the query prompt template
+        :param dialogue_history: dialogue entries, formatted with ``compose_dialogue``
+        :param query_prompt_template: optional template; falls back to QUERY_PROMPT_TEMPLATE
+        :return: the ``message`` argument of every ``output_multimodal_message`` tool call,
+                 in call order; empty when the tool was never called
+        """
         formatted_dialogue = MessagePushAgent.compose_dialogue(dialogue_history)
-        query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
+        query_prompt_template = query_prompt_template or QUERY_PROMPT_TEMPLATE
+        query = query_prompt_template.format(**context, dialogue_history=formatted_dialogue)
         self.run(query)
-        for tool_call in reversed(self.tool_call_records):
-            if tool_call['name'] == MessageNotifier.message_notify_user.__name__:
-                return tool_call['arguments']['message']
-        return ''
+        result = []
+        for tool_call in self.tool_call_records:
+            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
+                result.append(tool_call['arguments']['message'])
+        return result
 
     @staticmethod
     def compose_dialogue(dialogue: List[Dict]) -> str:
@@ -163,6 +166,9 @@ class DummyMessagePushAgent(MessagePushAgent):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> str:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict],
+                         query_prompt_template: Optional[str] = None) -> List[Dict]:
+        """Return a fixed pair of test items (one text, one image) without running the agent.
+
+        ``context`` must provide ``agent_name`` and ``nickname``; the other arguments are unused.
+        """
         logger.debug(f"DummyMessagePushAgent.generate_message called, context: {context}")
-        return "测试消息: {agent_name} -> {nickname}".format(**context)
+        result = [{"type": "text", "content": "测试消息: {agent_name} -> {nickname}".format(**context)},
+                  {"type": "image", "content": "https://example.com/test_image.jpg"}]
+        return result

+ 14 - 11
pqai_agent/agents/message_reply_agent.py

@@ -4,7 +4,7 @@ from typing import Optional, List, Dict
 from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
 from pqai_agent.chat_service import VOLCENGINE_MODEL_DEEPSEEK_V3
 from pqai_agent.logging_service import logger
-from pqai_agent.message import MessageType
+from pqai_agent.mq_message import MessageType
 from pqai_agent.toolkit.function_tool import FunctionTool
 from pqai_agent.toolkit.image_describer import ImageDescriber
 from pqai_agent.toolkit.message_notifier import MessageNotifier
@@ -79,10 +79,10 @@ QUERY_PROMPT_TEMPLATE = """现在,请以客服的角色分析以下会话并
 注意对话信息的格式为: [角色][时间][消息类型]对话内容
 注意分析客服和用户当前的社交阶段,先确立对话的目的。
 注意一定要分析对话信息中的时间,避免和当前时间段不符的内容!注意一定要结合历史的对话情况进行分析和问候方式的选择!
-使用message_notify_user发送最终的回复内容,调用时不要传入除了回复内容外的其它任何信息
+请调用output_multimodal_message工具发送最终的消息,如果有多条消息需要发送,可以多次调用output_multimodal_message,请务必保证所有内容都通过output_multimodal_message发出
 请注意这是微信聊天,如果用户使用了表情包,请使用analyse_image描述表情包,并分析其含义和情绪,如果要回复请尽量用简短的emoji或文字进行回复。
-如果用户连续2次以上感到疑惑,请先发送<人工介入>,后接你认为需要人工介入的原因。如果判断对话可自然结束、无需再回复用户,请发送<结束>。如果用户表现出强烈的负向情绪、要求不再对话,请发送<负向情绪结束>。
-以上特殊消息的发送请使用message_notify_user。
+特殊情况:如果用户连续2次以上感到疑惑,请先发送<人工介入>,后接你认为需要人工介入的原因。如果判断对话可自然结束、无需再回复用户,请发送<结束>。如果用户表现出强烈的负向情绪、要求不再对话,请发送<负向情绪结束>。
+以上特殊消息的发送请使用message_notify_user。
 Now, start to process your task. Please think step by step.
  """
 
@@ -101,14 +101,15 @@ class MessageReplyAgent(SimpleOpenAICompatibleChatAgent):
         ])
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> str:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
+        """Run the reply agent and collect the messages it emitted.
+
+        :param context: values substituted into QUERY_PROMPT_TEMPLATE
+        :param dialogue_history: dialogue entries, formatted with ``compose_dialogue``
+        :return: the ``message`` argument of every ``output_multimodal_message`` tool call,
+                 in call order; empty when the tool was never called
+        """
         formatted_dialogue = MessageReplyAgent.compose_dialogue(dialogue_history)
         query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
         self.run(query)
-        for tool_call in reversed(self.tool_call_records):
-            if tool_call['name'] == MessageNotifier.message_notify_user.__name__:
-                return tool_call['arguments']['message']
-        return ''
+        result = []
+        # NOTE(review): only output_multimodal_message calls are collected, but the query prompt
+        # still asks for special markers to be sent with message_notify_user - confirm those
+        # markers are not silently dropped here.
+        for tool_call in self.tool_call_records:
+            if tool_call['name'] == MessageNotifier.output_multimodal_message.__name__:
+                result.append(tool_call['arguments']['message'])
+        return result
 
     @staticmethod
     def compose_dialogue(dialogue: List[Dict]) -> str:
@@ -130,6 +131,8 @@ class DummyMessageReplyAgent(MessageReplyAgent):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> str:
+    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> List[Dict]:
+        """Return a fixed pair of test items (one text, one image) without running the agent.
+
+        ``context`` must provide ``agent_name`` and ``nickname``; ``dialogue_history`` is unused.
+        """
         logger.debug(f"DummyMessageReplyAgent.generate_message called, context: {context}")
-        return "测试消息: {agent_name} -> {nickname}".format(**context)
+        result = [{"type": "text", "content": "测试消息: {agent_name} -> {nickname}".format(**context)},
+                  {"type": "image", "content": "https://example.com/test_image.jpg"}]
+        return result

+ 2 - 26
pqai_agent/agents/simple_chat_agent.py

@@ -2,11 +2,10 @@ import json
 from typing import List, Optional
 
 from pqai_agent.agent import DEFAULT_MAX_RUN_STEPS
-from pqai_agent.chat_service import OpenAICompatible, VOLCENGINE_MODEL_DEEPSEEK_V3
+from pqai_agent.chat_service import OpenAICompatible
 from pqai_agent.logging_service import logger
 from pqai_agent.toolkit.function_tool import FunctionTool
-from pqai_agent.toolkit.image_describer import ImageDescriber
-from pqai_agent.toolkit.message_notifier import MessageNotifier
+
 
 
 class SimpleOpenAICompatibleChatAgent:
@@ -61,26 +60,3 @@ class SimpleOpenAICompatibleChatAgent:
             n_steps += 1
 
         raise Exception("Max run steps exceeded")
-
-if __name__ == '__main__':
-    import pqai_agent.logging_service
-    pqai_agent.logging_service.setup_root_logger()
-    tools = [
-        *ImageDescriber().get_tools(),
-        *MessageNotifier().get_tools()
-    ]
-    system_instruction = "You are a helpful assistant."
-    agent = SimpleOpenAICompatibleChatAgent(
-        model=VOLCENGINE_MODEL_DEEPSEEK_V3,
-        system_prompt=system_instruction,
-        tools=tools
-    )
-
-    user_input = query = """
-分析以下图片的内容:"http://wx.qlogo.cn/mmhead/Q3auHgzwzM5glpnBtDUianJErYf9AQsptLM3N78xP3sOR8SSibsG35HQ/0"
-根据内容联想作一首诗
-Please think step by step.
- """
-
-    result = agent.run(user_input)
-    print(result)

+ 2 - 0
pqai_agent/configs/dev.yaml

@@ -34,6 +34,8 @@ storage:
     table: agent_state
   chat_history:
     table: qywx_chat_history
+  push_record:
+    table: agent_push_record_dev
 
 agent_behavior:
   message_aggregation_sec: 3

+ 2 - 0
pqai_agent/configs/prod.yaml

@@ -34,6 +34,8 @@ storage:
     table: agent_state
   chat_history:
     table: qywx_chat_history
+  push_record:
+    # NOTE(review): this is the same "_dev" table name that dev.yaml uses - confirm prod
+    # should not point at a separate production table.
+    table: agent_push_record_dev
 
 chat_api:
   coze:

+ 78 - 14
pqai_agent/dialogue_manager.py

@@ -11,15 +11,17 @@ import textwrap
 import pymysql.cursors
 
 import cozepy
+from sqlalchemy.orm import sessionmaker, Session
 
 from pqai_agent import configs
+from pqai_agent.data_models.agent_push_record import AgentPushRecord
 from pqai_agent.logging_service import logger
 from pqai_agent.database import MySQLManager
 from pqai_agent import chat_service, prompt_templates
 from pqai_agent.history_dialogue_service import HistoryDialogueService
 
 from pqai_agent.chat_service import ChatServiceType
-from pqai_agent.message import MessageType, Message
+from pqai_agent.mq_message import MessageType, MqMessage
 from pqai_agent.toolkit.lark_alert_for_human_intervention import LarkAlertForHumanIntervention
 from pqai_agent.toolkit.lark_sheet_record_for_human_intervention import LarkSheetRecordForHumanIntervention
 from pqai_agent.user_manager import UserManager
@@ -99,7 +101,8 @@ class DialogueStateCache:
                       .format(staff_id, user_id, state, previous_state, rows))
 
 class DialogueManager:
-    def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache):
+    def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache,
+                 AgentDBSession: sessionmaker[Session]):
         config = configs.get()
 
         self.staff_id = staff_id
@@ -113,7 +116,8 @@ class DialogueManager:
         self.user_profile = self.user_manager.get_user_profile(user_id)
         self.staff_profile = self.user_manager.get_staff_profile(staff_id)
         # FIXME: 交互时间和对话记录都涉及到回滚
-        self.last_interaction_time = 0
+        self.last_interaction_time_ms = 0
+        self.last_active_interaction_time_sec = 0
         self.human_intervention_triggered = False
         self.vector_memory = DummyVectorMemoryManager(user_id)
         self.message_aggregation_sec = config.get('agent_behavior', {}).get('message_aggregation_sec', 5)
@@ -121,6 +125,7 @@ class DialogueManager:
         self.history_dialogue_service = HistoryDialogueService(
             config['storage']['history_dialogue']['api_base_url']
         )
+        self.AgentDBSession = AgentDBSession
         self._recover_state()
         # 由于本地状态管理过于复杂,引入事务机制做状态回滚
         self._uncommited_state_change = []
@@ -159,7 +164,7 @@ class DialogueManager:
         self.dialogue_history = self.history_dialogue_service.get_dialogue_history(
             self.staff_id, self.user_id, minutes_to_get)
         if self.dialogue_history:
-            self.last_interaction_time = self.dialogue_history[-1]['timestamp']
+            self.last_interaction_time_ms = self.dialogue_history[-1]['timestamp']
             if self.current_state == DialogueState.MESSAGE_AGGREGATING:
                 # 需要恢复未处理对话,找到dialogue_history中最后未处理的user消息
                 for entry in reversed(self.dialogue_history):
@@ -168,17 +173,25 @@ class DialogueManager:
                         break
         else:
             # 默认设置
-            self.last_interaction_time = int(time.time() * 1000) - minutes_to_get * 60 * 1000
-        time_for_read = datetime.fromtimestamp(self.last_interaction_time / 1000).strftime("%Y-%m-%d %H:%M:%S")
-        logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
+            self.last_interaction_time_ms = int(time.time() * 1000) - minutes_to_get * 60 * 1000
+        with self.AgentDBSession() as session:
+            # 读取数据库中的最后一次交互时间
+            query = session.query(AgentPushRecord).filter(
+                AgentPushRecord.staff_id == self.staff_id,
+                AgentPushRecord.user_id == self.user_id
+            ).order_by(AgentPushRecord.timestamp.desc()).first()
+            if query:
+                self.last_active_interaction_time_sec = query.timestamp
+        fmt_time = datetime.fromtimestamp(self.last_interaction_time_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
+        logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {fmt_time}")
 
     def update_interaction_time(self, timestamp_ms: int):
+        """Set the last interaction time (milliseconds) and record the change for rollback.
+
+        The previous value is appended to the uncommitted state changes before
+        the new one is stored.
+        """
         self._uncommited_state_change.append(DialogueStateChange(
             DialogueStateChangeType.INTERACTION_TIME,
-            self.last_interaction_time,
+            self.last_interaction_time_ms,
             timestamp_ms
         ))
-        self.last_interaction_time = timestamp_ms
+        self.last_interaction_time_ms = timestamp_ms
 
     def append_dialogue_history(self, message: Dict):
         self._uncommited_state_change.append(DialogueStateChange(
@@ -202,7 +215,7 @@ class DialogueManager:
             if entry.event_type == DialogueStateChangeType.STATE:
                 self.current_state, self.previous_state = entry.old
             elif entry.event_type == DialogueStateChangeType.INTERACTION_TIME:
-                self.last_interaction_time = entry.old
+                self.last_interaction_time_ms = entry.old
             elif entry.event_type == DialogueStateChangeType.DIALOGUE_HISTORY:
                 self.dialogue_history.pop()
             else:
@@ -226,7 +239,7 @@ class DialogueManager:
             (self.current_state, self.previous_state)
         ))
 
-    def update_state(self, message: Message) -> Tuple[bool, Optional[str]]:
+    def update_state(self, message: MqMessage) -> Tuple[bool, Optional[str]]:
         """根据用户消息更新对话状态,并返回是否需要发起回复 及下一条需处理的用户消息"""
         message_text = message.content
         message_ts = message.sendTime
@@ -255,7 +268,7 @@ class DialogueManager:
         if self.current_state == DialogueState.MESSAGE_AGGREGATING:
             # 收到的是特殊定时触发的空消息,且在聚合中,且已经超时,继续处理
             if message.type == MessageType.AGGREGATION_TRIGGER:
-                if message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
+                if message_ts - self.last_interaction_time_ms > self.message_aggregation_sec * 1000:
                     logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: exit aggregation waiting")
                 else:
                     logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: continue aggregation waiting")
@@ -409,6 +422,7 @@ class DialogueManager:
         message_ts = int(time.time() * 1000)
         self.append_dialogue_history({
             "role": "assistant",
+            "type": MessageType.TEXT,
             "content": llm_response,
             "timestamp": message_ts,
             "state": self.current_state.name
@@ -417,13 +431,59 @@ class DialogueManager:
 
         return llm_response
 
+    def generate_multimodal_response(self, item: Dict) -> Optional[Dict]:
+        """
+        Process one multimodal LLM response item, updating the dialogue state and history.
+        Note: every multimodal LLM response must go through this function!
+        The ``type`` of ``item`` is normalized in place: a string value (default "text")
+        is converted with ``MessageType.from_str``.
+        :param item: dict holding the multimodal content (``type`` and ``content`` keys)
+        :return: the normalized item when it should be sent; None when the dialogue is in
+                 HUMAN_INTERVENTION or a TEXT item carries a human-intervention / end marker
+        """
+        if self.current_state == DialogueState.HUMAN_INTERVENTION:
+            return None
+
+        raw_type = item.get("type", "text")
+        if isinstance(raw_type, str):
+            item["type"] = MessageType.from_str(raw_type)
+
+        if item["type"] == MessageType.TEXT:
+            # Human-intervention marker: switch state, raise an alert, send nothing.
+            if '<人工介入>' in item["content"]:
+                reason = item["content"].replace('<人工介入>', '')
+                logger.warning(f'staff[{self.staff_id}], user[{self.user_id}]: human intervention triggered, reason: {reason}')
+                self.do_state_change(DialogueState.HUMAN_INTERVENTION)
+                self._send_alert('人工介入', reason)
+                return None
+
+            # End markers: move to FAREWELL; the negative-emotion variant also raises an alert.
+            if '<结束>' in item["content"] or '<负向情绪结束>' in item["content"]:
+                logger.warning(f'staff[{self.staff_id}], user[{self.user_id}]: conversation ended')
+                self.do_state_change(DialogueState.FAREWELL)
+                if '<负向情绪结束>' in item["content"]:
+                    self._send_alert("用户负向情绪")
+                return None
+
+        # Record the response in the dialogue history
+        message_ts = int(time.time() * 1000)
+        self.append_dialogue_history({
+            "role": "assistant",
+            "type": item["type"],
+            "content": item["content"],
+            "timestamp": message_ts,
+            "state": self.current_state.name
+        })
+        self.update_interaction_time(message_ts)
+
+        return item
+
     def _get_hours_since_last_interaction(self, precision: int = -1):
+        """Return the hours elapsed since ``last_interaction_time_ms``.
+
+        :param precision: number of decimals to round to; a negative value returns the raw float
+        """
-        time_diff = (time.time() * 1000) - self.last_interaction_time
+        time_diff = (time.time() * 1000) - self.last_interaction_time_ms
         hours_passed = time_diff / 1000 / 3600
         if precision >= 0:
             return round(hours_passed, precision)
         return hours_passed
 
+    def update_last_active_interaction_time(self, timestamp_sec: int):
+        """Store the time (in seconds) of the latest agent-sent batch of messages."""
+        # Only the in-memory value is updated; after a restart it can be recovered from the database
+        self.last_active_interaction_time_sec = timestamp_sec
+
     def should_initiate_conversation(self) -> bool:
         """判断是否应该主动发起对话"""
         # 如果处于人工介入状态,不应主动发起对话
@@ -446,7 +506,11 @@ class DialogueManager:
             "high": 12
         }
 
-        threshold = thresholds.get(interaction_frequency, 12)
+        threshold = thresholds.get(interaction_frequency, 24)
+        #FIXME 05-21 临时策略,两次主动发起至少48小时
+        if time.time() - self.last_active_interaction_time_sec < 2 * 24 * 3600:
+            logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: last active interaction time too short")
+            return False
 
         if hours_passed < threshold:
             return False

+ 1 - 1
pqai_agent/history_dialogue_service.py

@@ -11,7 +11,7 @@ from pqai_agent.logging_service import logger
 import time
 
 from pqai_agent import configs
-from pqai_agent.message import MessageType
+from pqai_agent.mq_message import MessageType
 
 
 class HistoryDialogueService:

+ 15 - 15
pqai_agent/message_queue_backend.py

@@ -11,21 +11,21 @@ from rocketmq import ClientConfiguration, Credentials, SimpleConsumer
 from pqai_agent import configs
 from pqai_agent import logging_service
 from pqai_agent.logging_service import logger
-from pqai_agent.message import Message, MessageType, MessageChannel
+from pqai_agent.mq_message import MqMessage, MessageType, MessageChannel
 
 
 
 class MessageQueueBackend(abc.ABC):
     @abc.abstractmethod
-    def consume(self) -> Optional[Message]:
+    def consume(self) -> Optional[MqMessage]:
         pass
 
     @abc.abstractmethod
-    def ack(self, message: Message) -> None:
+    def ack(self, message: MqMessage) -> None:
         pass
 
     @abc.abstractmethod
-    def produce(self, message: Message, msg_group: Optional[str] = None) -> None:
+    def produce(self, message: MqMessage, msg_group: Optional[str] = None) -> None:
         pass
 
     @abc.abstractmethod
@@ -37,13 +37,13 @@ class MemoryQueueBackend(MessageQueueBackend):
     def __init__(self):
+        # Plain list used as the in-process message store.
         self._queue = []
 
-    def consume(self) -> Optional[Message]:
+    def consume(self) -> Optional[MqMessage]:
+        """Remove and return the message at the head of the queue.
+
+        Returns:
+            The first queued message, or None when the queue is empty.
+        """
         return self._queue.pop(0) if self._queue else None
 
-    def ack(self, message: Message):
+    def ack(self, message: MqMessage):
+        """No-op: the in-memory queue already removed the message in consume()."""
         return
 
-    def produce(self, message: Message, msg_group: Optional[str] = None):
+    def produce(self, message: MqMessage, msg_group: Optional[str] = None):
+        """Append ``message`` to the tail of the queue.
+
+        ``msg_group`` is accepted for interface compatibility and is not used
+        by this backend.
+        """
         self._queue.append(message)
 
     def shutdown(self):
@@ -78,7 +78,7 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
     def __del__(self):
         self.shutdown()
 
-    def consume(self, invisible_duration=60) -> Optional[Message]:
+    def consume(self, invisible_duration=60) -> Optional[MqMessage]:
         if not self.has_consumer:
             raise Exception("Consumer not initialized.")
         # TODO(zhoutian): invisible_duration实际是不同消息类型不同的,有些消息预期的处理时间会更长
@@ -89,7 +89,7 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
         body = rmq_message.body.decode('utf-8')
         logger.debug("[{}]recv message body: {}".format(self.topic, body))
         try:
-            message = Message.from_json(body)
+            message = MqMessage.from_json(body)
             message._rmq_message = rmq_message
         except Exception as e:
             logger.error("Invalid message: {}. Parsing error: {}".format(body, e))
@@ -98,13 +98,13 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
             return None
         return message
 
-    def ack(self, message: Message):
+    def ack(self, message: MqMessage):
+        """Acknowledge a consumed message on the underlying consumer.
+
+        Args:
+            message: A message carrying the original RocketMQ message in its
+                ``_rmq_message`` attribute.
+
+        Raises:
+            ValueError: If ``message._rmq_message`` is not set.
+        """
         if not message._rmq_message:
-            raise ValueError("Message not set with _rmq_message.")
+            raise ValueError("MqMessage not set with _rmq_message.")
         logger.debug("[{}]ack message: {}".format(self.topic, message))
         self.consumer.ack(message._rmq_message)
 
-    def produce(self, message: Message, msg_group: Optional[str] = None) -> None:
+    def produce(self, message: MqMessage, msg_group: Optional[str] = None) -> None:
         if not self.has_producer:
             raise Exception("Producer not initialized.")
         message.model_config['use_enum_values'] = False
@@ -148,16 +148,16 @@ if __name__ == '__main__':
         else:
             break
 
-    send_message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
+    send_message = MqMessage.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
                                  "user_id_1", "staff_id_0",
-                                 None, int(time.time() * 1000))
+                                   None, int(time.time() * 1000))
     queue.produce(send_message)
     recv_message = queue.consume()
     print(recv_message)
     if recv_message:
         queue.ack(recv_message)
 
-    send_message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+    send_message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
                                  "user_id_1", "staff_id_0",
                                  "message_queue_backend test", int(time.time() * 1000))
     queue.produce(send_message)

+ 17 - 3
pqai_agent/message.py → pqai_agent/mq_message.py

@@ -60,6 +60,20 @@ class MessageType(int, Enum):
             104: "进入人工介入状态"
         }[code]
 
+    @staticmethod
+    def from_str(type_str: str) -> 'MessageType':
+        """Convert a type name string into a MessageType member.
+
+        The lookup is case-insensitive (the input is upper-cased) and is done
+        by member name. The generic names ``image`` and ``video`` are
+        rewritten to ``IMAGE_QW`` and ``VIDEO_QW`` before the lookup.
+
+        Args:
+            type_str: Name of the message type, e.g. ``"text"``.
+
+        Returns:
+            The matching MessageType member.
+
+        Raises:
+            ValueError: If no member has the resulting name.
+        """
+        upper_str = type_str.upper()
+        if upper_str == 'IMAGE':
+            # Special case: the generic IMAGE type maps to the IMAGE_QW member.
+            # NOTE(review): presumably IMAGE_QW / VIDEO_QW exist on this enum;
+            # their definitions are outside this hunk — confirm.
+            upper_str = 'IMAGE_QW'
+        elif upper_str == 'VIDEO':
+            upper_str = 'VIDEO_QW'
+        try:
+            return MessageType[upper_str]
+        except KeyError:
+            raise ValueError(f"Unknown message type: {type_str}")
+
 # class MessageChannel(Enum):
 #     CORP_WECHAT = (1, "企业微信")
 #     MINI_PROGRAM = (2, "小程序")
@@ -85,7 +99,7 @@ class MessageChannel(int, Enum):
             101: "系统内部"
         }[code]
 
-class Message(BaseModel):
+class MqMessage(BaseModel):
      msgId: Optional[int] = None
      type: MessageType
      channel: MessageChannel
@@ -102,7 +116,7 @@ class Message(BaseModel):
 
      @staticmethod
      def build(type, channel, sender, receiver, content, timestamp):
-         return Message(
+         return MqMessage(
              msgId=0,
              type=type,
              channel=channel,
@@ -120,4 +134,4 @@ class Message(BaseModel):
 
      @staticmethod
      def from_json(json_str):
+         """Build an MqMessage by validating the given JSON string against the model."""
-         return Message.model_validate_json(json_str)
+         return MqMessage.model_validate_json(json_str)

+ 40 - 23
pqai_agent/push_service.py

@@ -6,7 +6,7 @@ from datetime import datetime
 from enum import Enum
 from concurrent.futures import ThreadPoolExecutor
 from threading import Thread
-from typing import Optional, Dict
+from typing import Optional, Dict, List
 
 import rocketmq
 from rocketmq import ClientConfiguration, Credentials, SimpleConsumer, FilterExpression
@@ -16,7 +16,7 @@ from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessageP
 from pqai_agent.configs import apollo_config
 from pqai_agent.data_models.agent_push_record import AgentPushRecord
 from pqai_agent.logging_service import logger
-from pqai_agent.message import MessageType
+from pqai_agent.mq_message import MessageType
 
 
 class TaskType(Enum):
@@ -30,7 +30,7 @@ def generate_task_rmq_message(topic: str, staff_id: str, user_id: str, task_type
         'staff_id': staff_id,
         'user_id': user_id,
         'task_type': task_type.value,
-        # FIXME: 需要支持多模态消息
+        # NOTE:通过传入JSON支持多模态消息
         'content': content or '',
         'timestamp': int(time.time() * 1000),
     }, ensure_ascii=False).encode('utf-8')
@@ -142,25 +142,42 @@ class PushTaskWorkerPool:
                 logger.debug(f"user[{user_id}], do not initiate conversation")
                 self.consumer.ack(msg)
                 return
-            content = task['content']
+            contents: List[Dict] = json.loads(task['content'])
+            if not contents:
+                logger.debug(f"staff[{staff_id}], user[{user_id}]: empty content, do not send")
+                self.consumer.ack(msg)
+                return
             recent_dialogue = agent.dialogue_history[-10:]
             agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
-            # FIXME(zhoutian): 不应该再由agent控制,或者agent和API共享同一配置
-            if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
-                message_type = MessageType.TEXT
-            else:
-                message_type = self.agent_service.response_type_detector.detect_type(
-                    recent_dialogue, content, enable_random=True)
-            response = agent.generate_response(content)
-            if response:
-                with self.agent_service.AgentDBSession() as session:
-                    msg_list = [{'type': MessageType.TEXT.value, 'content': response}]
-                    record = AgentPushRecord(staff_id=staff_id, user_id=user_id,
-                                             content=json.dumps(msg_list, ensure_ascii=False),
-                                             timestamp=int(datetime.now().timestamp()))
-                    session.add(record)
-                    session.commit()
-                self.agent_service.send_response(staff_id, user_id, response, message_type, skip_check=True)
+            messages_to_send = []
+            for item in contents:
+                if item["type"] == "text":
+                    if staff_id not in agent_voice_whitelist:
+                        message_type = MessageType.TEXT
+                    else:
+                        message_type = self.agent_service.response_type_detector.detect_type(
+                            recent_dialogue, item["content"], enable_random=True)
+                    response = agent.generate_response(item["content"])
+                    if response:
+                        messages_to_send.append({'type': message_type, 'content': response})
+                else:
+                    message_type = MessageType.from_str(item["type"])
+                    response = agent.generate_multimodal_response(item)
+                    if response:
+                        item["type"] = message_type
+                        messages_to_send.append(item)
+            current_ts = int(time.time())
+            with self.agent_service.AgentDBSession() as session:
+                msg_list = [{"type": msg["type"].value, "content": msg["content"]} for msg in messages_to_send]
+                record = AgentPushRecord(staff_id=staff_id, user_id=user_id,
+                                         content=json.dumps(msg_list, ensure_ascii=False),
+                                         timestamp=current_ts)
+                session.add(record)
+                session.commit()
+            if messages_to_send:
+                for response in messages_to_send:
+                    self.agent_service.send_multimodal_response(staff_id, user_id, response, skip_check=True)
+                agent.update_last_active_interaction_time(current_ts)
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: generate empty response")
             self.consumer.ack(msg)
@@ -178,12 +195,12 @@ class PushTaskWorkerPool:
             message_to_user = push_agent.generate_message(
                 context=main_agent.get_prompt_context(None),
                 dialogue_history=self.agent_service.history_dialogue_db.get_dialogue_history_backward(
-                    staff_id, user_id, main_agent.last_interaction_time, limit=100
+                    staff_id, user_id, main_agent.last_interaction_time_ms, limit=100
                 )
             )
             if message_to_user:
-                rmq_message = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.SEND, message_to_user)
-                logger.debug(f"send message: {rmq_message.body.decode('utf-8')}")
+                rmq_message = generate_task_rmq_message(
+                    self.rmq_topic, staff_id, user_id, TaskType.SEND, json.dumps(message_to_user))
                 self.producer.send(rmq_message)
             else:
                 logger.info(f"staff[{staff_id}], user[{user_id}]: no push message generated")

+ 10 - 2
pqai_agent/rate_limiter.py

@@ -3,7 +3,11 @@
 # vim:fenc=utf-8
 
 import time
+from typing import Optional, Union, Dict
+
 from pqai_agent.logging_service import logger
+from pqai_agent.mq_message import MessageType
+
 
 class MessageSenderRateLimiter:
     MAX_CHAR_PER_SECOND = 5
@@ -11,11 +15,15 @@ class MessageSenderRateLimiter:
     def __init__(self):
         self.last_send_time = {}
 
-    def wait_for_sending(self, sender_id: str, next_message: str):
+    def wait_for_sending(self, sender_id: str, next_message: Union[str, Dict]):
         current_time = time.time()
         last_send_time = self.last_send_time.get(sender_id, 0)
         elapsed_time = current_time - last_send_time
-        required_time = len(next_message) / self.MAX_CHAR_PER_SECOND
+        if isinstance(next_message, str) or next_message["type"] in (MessageType.TEXT, MessageType.VOICE):
+            required_time = len(next_message) / self.MAX_CHAR_PER_SECOND
+        else:
+            # FIXME: 非文字消息的判断方式
+            required_time = 2
         if elapsed_time < required_time:
             logger.debug(f"Rate limit exceeded. Waiting for {required_time - elapsed_time:.2f} seconds.")
             time.sleep(required_time - elapsed_time)

+ 1 - 1
pqai_agent/response_type_detector.py

@@ -13,7 +13,7 @@ from pqai_agent import configs
 from pqai_agent import prompt_templates
 from pqai_agent.dialogue_manager import DialogueManager
 from pqai_agent.logging_service import logger
-from pqai_agent.message import MessageType
+from pqai_agent.mq_message import MessageType
 
 
 class ResponseTypeDetector:

+ 32 - 0
pqai_agent/toolkit/base.py

@@ -3,6 +3,23 @@ import functools
 import threading
 from pqai_agent.toolkit.function_tool import FunctionTool
 
+
+class ToolServiceError(Exception):
+    """Error raised by tool services, wrapping either an exception or a code/message pair.
+
+    When ``exception`` is given it becomes the sole argument of the base
+    ``Exception``; otherwise the base message is formatted from ``code`` and
+    ``message``. All four constructor arguments are kept as attributes.
+
+    Args:
+        exception: Underlying exception to wrap, if any.
+        code: Error code used in the formatted message.
+        message: Human-readable error description.
+        extra: Additional data to attach to the error.
+    """
+    # NOTE(review): `Optional` must be imported from `typing` in this module;
+    # the import lines above this hunk are not visible here — confirm.
+
+    def __init__(self,
+                 exception: Optional[Exception] = None,
+                 code: Optional[str] = None,
+                 message: Optional[str] = None,
+                 extra: Optional[dict] = None):
+        if exception is not None:
+            super().__init__(exception)
+        else:
+            super().__init__(f'\nError code: {code}. Error message: {message}')
+        self.exception = exception
+        self.code = code
+        self.message = message
+        self.extra = extra
+
 def with_timeout(timeout=None):
     r"""Decorator that adds timeout functionality to functions.
 
@@ -100,3 +117,18 @@ class BaseToolkit:
                 representing the functions in the toolkit.
         """
         raise NotImplementedError("Subclasses must implement this method.")
+
+    def get_tool(self, name: str) -> FunctionTool:
+        r"""Returns the FunctionTool from :meth:`get_tools` whose name matches.
+
+        Args:
+            name (str): The name of the tool to retrieve.
+
+        Returns:
+            FunctionTool: The first tool whose ``name`` attribute equals
+                the requested name.
+
+        Raises:
+            NotImplementedError: If no tool in the toolkit has that name
+                (this method never returns None).
+        """
+        # NOTE(review): relies on FunctionTool exposing a `name` attribute —
+        # confirm against pqai_agent.toolkit.function_tool.
+        tools = self.get_tools()
+        for tool in tools:
+            if tool.name == name:
+                return tool
+        # NOTE(review): NotImplementedError is an unusual signal for a failed
+        # lookup; callers must catch this type rather than KeyError/ValueError.
+        raise NotImplementedError("Tool not found in the toolkit.")

+ 6 - 0
pqai_agent/toolkit/message_notifier.py

@@ -37,6 +37,12 @@ class MessageNotifier(BaseToolkit):
         Returns:
             str: A confirmation message.
         """
+        if message["type"] not in ["text", "image", "gif", "video", "mini_program"]:
+            return f"Invalid message type: {message['type']}"
+        if message["type"] in ("video", "mini_program") and "title" not in message:
+            return "Title is required for video or mini_program messages."
+        if message["type"] == "mini_program" and "cover_image" not in message:
+            return "Cover image is required for mini_program messages."
         logger.info(f"Multimodal message to user: {message}")
         return 'success'
 

+ 28 - 0
pqai_agent/toolkit/search_toolkit.py

@@ -188,6 +188,33 @@ class SearchToolkit(BaseToolkit):
         except Exception as e:
             return {"error": f"Bing scraping error: {e!s}"}
 
+    def aiddit_search(self, keyword: str) -> Dict[str, Any]:
+        r"""Search using Aiddit API.
+
+        Sends the keyword as a JSON POST body to the Aiddit search endpoint
+        and returns at most the first five entries of the response's
+        ``data.results`` list.
+
+        Args:
+            keyword (str): The search keyword.
+        Returns:
+            Dict[str, Any]: ``{'results': [...]}`` on success, or
+                ``{'error': '...'}`` when the request fails or the response
+                ``code`` field is not 0.
+        """
+        url = "http://smcp-api.aiddit.com/mcp/custom/search"
+        headers = {
+            "Content-Type": "application/json",
+        }
+        data = {
+            "keyword": keyword
+        }
+        try:
+            # NOTE(review): no `timeout` is passed, so this call can block
+            # indefinitely if the service does not respond.
+            response = requests.post(url, headers=headers, json=data)
+            response.raise_for_status()
+            resp_json = response.json()
+            if resp_json.get('code') != 0:
+                return {"error": f"Aiddit search error: {resp_json.get('message', 'Unknown error')}"}
+            # NOTE(review): a response without a 'data' key raises KeyError
+            # here, which the except clause below does not catch.
+            resp_data = resp_json['data']
+            # Limit to 5 results
+            results = resp_data.get('results', [])[:5]
+            return {'results': results}
+        except requests.RequestException as e:
+            return {"error": f"Aiddit search error: {e!s}"}
+
     def get_tools(self) -> List[FunctionTool]:
         r"""Returns a list of FunctionTool objects representing the
         functions in the toolkit.
@@ -199,4 +226,5 @@ class SearchToolkit(BaseToolkit):
         return [
             FunctionTool(self.search_baidu),
             FunctionTool(self.search_bing),
+            FunctionTool(self.aiddit_search),
         ]

+ 5 - 5
pqai_agent_server/agent_server.py

@@ -6,7 +6,7 @@ from pqai_agent import configs, logging_service
 from pqai_agent.agent_service import AgentService
 from pqai_agent.chat_service import ChatServiceType
 from pqai_agent.logging_service import logger
-from pqai_agent.message import MessageType, Message, MessageChannel
+from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend, MemoryQueueBackend
 from pqai_agent.push_service import PushTaskWorkerPool, PushScanThread
 from pqai_agent.user_manager import LocalUserManager, LocalUserRelationManager, MySQLUserManager, \
@@ -93,14 +93,14 @@ if __name__ == "__main__":
         receiver = '1688855931724582'
         if text in (MessageType.AGGREGATION_TRIGGER.name,
                     MessageType.HUMAN_INTERVENTION_END.name):
-            message = Message.build(
+            message = MqMessage.build(
                 MessageType.__members__.get(text),
                 MessageChannel.CORP_WECHAT,
                 sender, receiver, None, int(time.time() * 1000))
         else:
-            message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
-                                    sender,receiver, text, int(time.time() * 1000)
-                                    )
+            message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+                                      sender, receiver, text, int(time.time() * 1000)
+                                      )
         message.msgId = message_id
         receive_queue.produce(message)
         time.sleep(0.1)

+ 1 - 1
pqai_agent_server/utils/prompt_util.py

@@ -9,7 +9,7 @@ from pqai_agent import logging_service, chat_service
 from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.user_profile_extractor import UserProfileExtractor
 from pqai_agent.dialogue_manager import DialogueManager
-from pqai_agent.message import MessageType
+from pqai_agent.mq_message import MessageType
 from pqai_agent.utils.prompt_utils import format_agent_profile
 
 logger = logging_service.logger

+ 5 - 5
scripts/mq_sender.py

@@ -1,6 +1,6 @@
 import sys
 from pqai_agent import configs
-from pqai_agent.message import Message, MessageType, MessageChannel
+from pqai_agent.mq_message import MqMessage, MessageType, MessageChannel
 from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend
 import time
 from argparse import ArgumentParser
@@ -33,13 +33,13 @@ if __name__ == '__main__':
         receiver = args.staff_id
         if text in (MessageType.AGGREGATION_TRIGGER.name,
                     MessageType.HUMAN_INTERVENTION_END.name):
-            message = Message.build(
+            message = MqMessage.build(
                 MessageType.__members__.get(text),
                 MessageChannel.CORP_WECHAT,
                 sender, receiver, None, int(time.time() * 1000))
         else:
-            message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
-                                    sender,receiver, text, int(time.time() * 1000)
-                                    )
+            message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+                                      sender, receiver, text, int(time.time() * 1000)
+                                      )
         message.msgId = message_id
         receive_queue.produce(message)

+ 3 - 3
scripts/resend_lost_message.py

@@ -5,7 +5,7 @@
 from datetime import datetime
 import re
 from pqai_agent import configs
-from pqai_agent.message import MessageChannel, Message, MessageType
+from pqai_agent.mq_message import MessageChannel, MqMessage, MessageType
 from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend
 from pqai_agent.user_manager import MySQLUserRelationManager
 
@@ -102,8 +102,8 @@ def main():
                     # Check if user has already been processed
                     if user_id in processed_users:
                         break
-                    message = Message.build(message_type, MessageChannel.CORP_WECHAT,
-                                            staff_id, user_id, response, int(timestamp.timestamp() * 1000))
+                    message = MqMessage.build(message_type, MessageChannel.CORP_WECHAT,
+                                              staff_id, user_id, response, int(timestamp.timestamp() * 1000))
                     print(message)
                     # Send the message
                     send_queue.produce(message)

+ 6 - 6
tests/unit_test.py

@@ -6,7 +6,7 @@ import pytest
 from unittest.mock import Mock, MagicMock
 from pqai_agent.agent_service import AgentService, MemoryQueueBackend
 from pqai_agent.dialogue_manager import DialogueState, TimeContext
-from pqai_agent.message import MessageType, Message, MessageChannel
+from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
 from pqai_agent.response_type_detector import ResponseTypeDetector
 from pqai_agent.user_manager import LocalUserManager
 import time
@@ -113,7 +113,7 @@ def test_normal_conversation_flow(test_env):
     service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
 
     # 准备测试消息
-    test_msg = Message.build(
+    test_msg = MqMessage.build(
         MessageType.TEXT, MessageChannel.CORP_WECHAT,
         'user_id_0', 'staff_id_0', '你好', int(time.time() * 1000))
     queues.receive_queue.produce(test_msg)
@@ -136,11 +136,11 @@ def test_aggregated_conversation_flow(test_env):
 
     # 准备测试消息
     ts_begin = int(time.time() * 1000)
-    test_msg = Message.build(
+    test_msg = MqMessage.build(
         MessageType.TEXT, MessageChannel.CORP_WECHAT,
         'user_id_0', 'staff_id_0', '你好', ts_begin)
     queues.receive_queue.produce(test_msg)
-    test_msg = Message.build(
+    test_msg = MqMessage.build(
         MessageType.TEXT, MessageChannel.CORP_WECHAT,
         'user_id_0', 'staff_id_0', '我是老李', ts_begin + 500)
     queues.receive_queue.produce(test_msg)
@@ -162,7 +162,7 @@ def test_aggregated_conversation_flow(test_env):
     assert sent_msg is None
 
     # 模拟定时器产生空消息触发响应
-    service.process_single_message(Message.build(
+    service.process_single_message(MqMessage.build(
         MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
         'user_id_0', 'staff_id_0', None, ts_begin + 2000
     ))
@@ -178,7 +178,7 @@ def test_human_intervention_trigger(test_env):
     service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
 
     # 准备需要人工干预的消息
-    test_msg = Message.build(
+    test_msg = MqMessage.build(
         MessageType.TEXT, MessageChannel.CORP_WECHAT,
         "user_id_0", "staff_id_0",
         "我需要帮助!", int(time.time() * 1000)