Browse Source

Merge remote-tracking branch 'origin/feature/202503-init' into feature/luojunhui-20250514-server

# Conflicts:
#	pqai_agent/dialogue_manager.py
luojunhui 1 month ago
parent
commit
941daf65a7

+ 1 - 0
.gitignore

@@ -58,3 +58,4 @@ docs/_build/
 # PyBuilder
 target/
 
+user_profiles/

+ 92 - 113
pqai_agent/agent_service.py

@@ -13,24 +13,25 @@ import threading
 import traceback
 
 import apscheduler.triggers.cron
+import rocketmq
 from apscheduler.schedulers.background import BackgroundScheduler
 
 from pqai_agent import configs
-from pqai_agent import logging_service
 from pqai_agent.configs import apollo_config
+from pqai_agent.exceptions import NoRetryException
 from pqai_agent.logging_service import logger
 from pqai_agent import chat_service
 from pqai_agent.chat_service import CozeChat, ChatServiceType
 from pqai_agent.dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
+from pqai_agent.history_dialogue_service import HistoryDialogueDatabase
+from pqai_agent.push_service import PushScanThread, PushTaskWorkerPool
 from pqai_agent.rate_limiter import MessageSenderRateLimiter
 from pqai_agent.response_type_detector import ResponseTypeDetector
-from pqai_agent.user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager, \
-    LocalUserRelationManager
+from pqai_agent.user_manager import UserManager, UserRelationManager
 from pqai_agent.message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
 from pqai_agent.user_profile_extractor import UserProfileExtractor
 from pqai_agent.message import MessageType, Message, MessageChannel
 
-
 class AgentService:
     def __init__(
         self,
@@ -41,6 +42,8 @@ class AgentService:
         user_relation_manager: UserRelationManager,
         chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
     ):
+        self.config = configs.get()
+
         self.receive_queue = receive_backend
         self.send_queue = send_backend
         self.human_queue = human_backend
@@ -52,8 +55,8 @@ class AgentService:
         self.user_profile_extractor = UserProfileExtractor()
         self.response_type_detector = ResponseTypeDetector()
         self.agent_registry: Dict[str, DialogueManager] = {}
+        self.history_dialogue_db = HistoryDialogueDatabase(self.config['storage']['user']['mysql'])
 
-        self.config = configs.get()
         chat_config = self.config['chat_api']['openai_compatible']
         self.text_model_name = chat_config['text_model']
         self.multimodal_model_name = chat_config['multimodal_model']
@@ -80,6 +83,13 @@ class AgentService:
         self.process_thread = None
         self._sigint_cnt = 0
 
+        # Push相关
+        self.push_task_producer = None
+        self.push_task_consumer = None
+        self._init_push_task_queue()
+        self.next_push_disabled = True
+        self._resume_unfinished_push_task()
+
         self.send_rate_limiter = MessageSenderRateLimiter()
 
     def setup_initiative_conversations(self, schedule_params: Optional[Dict] = None):
@@ -102,7 +112,8 @@ class AgentService:
                 topic,
                 has_consumer=True, has_producer=True,
                 group_id=mq_conf['scheduler_group'],
-                topic_type='DELAY'
+                topic_type='DELAY',
+                await_duration=5
             )
             self.msg_scheduler_thread = threading.Thread(target=self.process_scheduler_events)
             self.msg_scheduler_thread.start()
@@ -127,13 +138,15 @@ class AgentService:
         else:
             logger.warning(f"Unknown message type: {msg.type}")
 
-    def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
+    def get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
         """获取Agent实例"""
         agent_key = 'agent_{}_{}'.format(staff_id, user_id)
         if agent_key not in self.agent_registry:
             self.agent_registry[agent_key] = DialogueManager(
                 staff_id, user_id, self.user_manager, self.agent_state_cache)
-        return self.agent_registry[agent_key]
+        agent = self.agent_registry[agent_key]
+        agent.refresh_profile()
+        return agent
 
     def process_messages(self):
         """持续处理接收队列消息"""
@@ -143,15 +156,18 @@ class AgentService:
                 try:
                     self.process_single_message(message)
                     self.receive_queue.ack(message)
+                except NoRetryException as e:
+                    logger.error("Error processing message and skip retry: {}".format(e))
+                    self.receive_queue.ack(message)
                 except Exception as e:
-                    logger.error("Error processing message: {}".format(e))
-                    traceback.print_exc()
-            time.sleep(1)
+                    error_stack = traceback.format_exc()
+                    logger.error("Error processing message: {}, {}".format(e, error_stack))
+            time.sleep(0.5)
         logger.info("Message processing thread exit")
 
     def start(self, blocking=False):
         self.running = True
-        self.process_thread = threading.Thread(target=service.process_messages)
+        self.process_thread = threading.Thread(target=self.process_messages)
         self.process_thread.start()
         self.setup_scheduler()
         # 只有企微场景需要主动发起
@@ -217,10 +233,10 @@ class AgentService:
 
         # 获取用户信息和Agent实例
         user_profile = self.user_manager.get_user_profile(user_id)
-        agent = self._get_agent_instance(staff_id, user_id)
+        agent = self.get_agent_instance(staff_id, user_id)
         if not agent.is_valid():
             logger.error(f"staff[{staff_id}] user[{user_id}]: agent is invalid")
-            return
+            raise Exception('agent is invalid')
 
         # 更新对话状态
         logger.debug("process message: {}".format(message))
@@ -242,13 +258,13 @@ class AgentService:
                 resp = self._get_chat_response(user_id, agent, message_text)
                 if resp:
                     recent_dialogue = agent.dialogue_history[-10:]
-                    agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist"))
+                    agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
                     if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
                         message_type = MessageType.TEXT
                     else:
                         message_type = self.response_type_detector.detect_type(
                             recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
-                    self._send_response(staff_id, user_id, resp, message_type)
+                    self.send_response(staff_id, user_id, resp, message_type)
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
             # 当前消息处理成功,commit并持久化agent状态
@@ -257,15 +273,15 @@ class AgentService:
             agent.rollback_state()
             raise e
 
-    def _send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
+    def send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
         logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
         current_ts = int(time.time() * 1000)
         user_tags = self.user_relation_manager.get_user_tags(user_id)
-        white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags"))
+        white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags", []))
         hit_white_list_tags = len(set(user_tags).intersection(white_list_tags)) > 0
         # FIXME(zhoutian)
         # 测试期间临时逻辑,只发送特定的账号或特定用户
-        staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs"))
+        staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs", []))
         if not (staff_id in staff_white_lists or hit_white_list_tags or skip_check):
             logger.warning(f"staff[{staff_id}] user[{user_id}]: skip reply")
             return
@@ -286,23 +302,75 @@ class AgentService:
             int(time.time() * 1000)
         ))
 
+    def _init_push_task_queue(self):  # Create and start the RocketMQ producer and simple consumer used for push tasks.
+        credentials = rocketmq.Credentials()  # constructed with no arguments
+        mq_conf = configs.get()['mq']  # endpoints, instance_id and the push-task topic/group are read from the 'mq' config section
+        rmq_client_conf = rocketmq.ClientConfiguration(mq_conf['endpoints'], credentials, mq_conf['instance_id'])
+        rmq_topic = mq_conf['push_tasks_topic']
+        rmq_group = mq_conf['push_tasks_group']
+        self.push_task_rmq_topic = rmq_topic  # kept on the instance; handed to PushTaskWorkerPool / PushScanThread by other methods of this class
+        self.push_task_producer = rocketmq.Producer(rmq_client_conf, (rmq_topic,))
+        self.push_task_producer.startup()
+        self.push_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group, await_duration=5)  # NOTE(review): await_duration is presumably in seconds - confirm against the client docs
+        self.push_task_consumer.startup()
+        self.push_task_consumer.subscribe(rmq_topic)  # the subscription is registered after startup()
+
+
+    def _resume_unfinished_push_task(self):  # Run a PushTaskWorkerPool on a separate thread and clear next_push_disabled once it has finished.
+        def run_unfinished_push_task():  # thread target; closes over self
+            logger.info("start to resume unfinished push task")
+            push_task_worker_pool = PushTaskWorkerPool(
+                self, self.push_task_rmq_topic, self.push_task_consumer, self.push_task_producer)
+            push_task_worker_pool.start()
+            push_task_worker_pool.wait_to_finish()
+            self.next_push_disabled = False  # _check_initiative_conversations returns early while this flag is True
+            logger.info("unfinished push tasks should be finished")
+        thread = threading.Thread(target=run_unfinished_push_task)
+        thread.start()  # the thread is not joined here, so this method returns without waiting for the pool
+
     def _check_initiative_conversations(self):  # Start one PushScanThread per staff plus a PushTaskWorkerPool, then wait for both to finish.
         logger.info("start to check initiative conversations")
+        if self.next_push_disabled:  # flag is cleared by the thread started in _resume_unfinished_push_task
+            logger.info("previous push tasks in processing, next push is disabled")
+            return
         if not DialogueManager.is_time_suitable_for_active_conversation():
             logger.info("time is not suitable for active conversation")
             return
-        white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags'))
+
+        push_scan_threads = []
+        for staff in self.user_relation_manager.list_staffs():
+            staff_id = staff['third_party_user_id']
+            scan_thread = threading.Thread(target=PushScanThread(
+                staff_id, self, self.push_task_rmq_topic, self.push_task_producer).run)
+            scan_thread.start()
+            push_scan_threads.append(scan_thread)
+
+        push_task_worker_pool = PushTaskWorkerPool(
+            self, self.push_task_rmq_topic, self.push_task_consumer, self.push_task_producer)
+        push_task_worker_pool.start()  # the pool is started before the scan threads are joined
+        for thread in push_scan_threads:
+            thread.join()
+        # Scanning and generation are asynchronous, so messages may still be unprocessed between two scans and tasks could be generated twice; therefore wait for the previous run to finish.
+        # The catch: if a new PushTaskWorkerPool is created each time and the previous run exits with messages still unprocessed, those unprocessed messages pile up.
+        push_task_worker_pool.wait_to_finish()
+
+    def _check_initiative_conversations_v1(self):
+        logger.info("start to check initiative conversations")
+        if not DialogueManager.is_time_suitable_for_active_conversation():
+            logger.info("time is not suitable for active conversation")
+            return
+        white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags', []))
         first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
         # 合并白名单,减少配置成本
         white_list_tags.update(first_initiate_tags)
-        voice_tags = set(apollo_config.get_json_value('agent_initiate_by_voice_tags'))
+        voice_tags = set(apollo_config.get_json_value('agent_initiate_by_voice_tags', []))
 
 
         """定时检查主动发起对话"""
         for staff_user in self.user_relation_manager.list_staff_users():
             staff_id = staff_user['staff_id']
             user_id = staff_user['user_id']
-            agent = self._get_agent_instance(staff_id, user_id)
+            agent = self.get_agent_instance(staff_id, user_id)
             should_initiate = agent.should_initiate_conversation()
             user_tags = self.user_relation_manager.get_user_tags(user_id)
 
@@ -326,7 +394,7 @@ class AgentService:
                             message_type = MessageType.VOICE
                         else:
                             message_type = MessageType.TEXT
-                        self._send_response(staff_id, user_id, resp, message_type, skip_check=True)
+                        self.send_response(staff_id, user_id, resp, message_type, skip_check=True)
                     agent.persist_state()
                 except Exception as e:
                     # FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突
@@ -398,93 +466,4 @@ class AgentService:
         pattern = r'\[?\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]?'
         response = re.sub(pattern, '', response)
         response = response.strip()
-        return response
-
-if __name__ == "__main__":
-    config = configs.get()
-    logging_service.setup_root_logger()
-    logger.warning("current env: {}".format(configs.get_env()))
-    scheduler_logger = logging.getLogger('apscheduler')
-    scheduler_logger.setLevel(logging.WARNING)
-
-    use_aliyun_mq = config['debug_flags']['use_aliyun_mq']
-
-    # 初始化不同队列的后端
-    if use_aliyun_mq:
-        receive_queue = AliyunRocketMQQueueBackend(
-            config['mq']['endpoints'],
-            config['mq']['instance_id'],
-            config['mq']['receive_topic'],
-            has_consumer=True, has_producer=True,
-            group_id=config['mq']['receive_group'],
-            topic_type='FIFO'
-        )
-        send_queue = AliyunRocketMQQueueBackend(
-            config['mq']['endpoints'],
-            config['mq']['instance_id'],
-            config['mq']['send_topic'],
-            has_consumer=False, has_producer=True,
-            topic_type='FIFO'
-        )
-    else:
-        receive_queue = MemoryQueueBackend()
-        send_queue = MemoryQueueBackend()
-    human_queue = MemoryQueueBackend()
-
-    # 初始化用户管理服务
-    # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
-    user_db_config = config['storage']['user']
-    staff_db_config = config['storage']['staff']
-    wecom_db_config = config['storage']['user_relation']
-    if config['debug_flags'].get('use_local_user_storage', False):
-        user_manager = LocalUserManager()
-        user_relation_manager = LocalUserRelationManager()
-    else:
-        user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
-        user_relation_manager = MySQLUserRelationManager(
-            user_db_config['mysql'], wecom_db_config['mysql'],
-            config['storage']['staff']['table'],
-            user_db_config['table'],
-            wecom_db_config['table']['staff'],
-            wecom_db_config['table']['relation'],
-            wecom_db_config['table']['user']
-        )
-
-    # 创建Agent服务
-    service = AgentService(
-        receive_backend=receive_queue,
-        send_backend=send_queue,
-        human_backend=human_queue,
-        user_manager=user_manager,
-        user_relation_manager=user_relation_manager,
-        chat_service_type=ChatServiceType.COZE_CHAT
-    )
-
-    if not config['debug_flags'].get('console_input', False):
-        service.start(blocking=True)
-        sys.exit(0)
-    else:
-        service.start()
-
-    message_id = 0
-    while service.running:
-        print("Input next message: ")
-        text = sys.stdin.readline().strip()
-        if not text:
-            continue
-        message_id += 1
-        sender = '7881301903997433'
-        receiver = '1688855931724582'
-        if text in (MessageType.AGGREGATION_TRIGGER.name,
-                    MessageType.HUMAN_INTERVENTION_END.name):
-            message = Message.build(
-                MessageType.__members__.get(text),
-                MessageChannel.CORP_WECHAT,
-                sender, receiver, None, int(time.time() * 1000))
-        else:
-            message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
-                sender,receiver, text, int(time.time() * 1000)
-            )
-        message.msgId = message_id
-        receive_queue.produce(message)
-        time.sleep(0.1)
+        return response

+ 86 - 18
pqai_agent/agents/message_push_agent.py

@@ -1,3 +1,4 @@
+import datetime
 from typing import Optional, List, Dict
 
 from pqai_agent.agents.simple_chat_agent import SimpleOpenAICompatibleChatAgent
@@ -9,11 +10,10 @@ from pqai_agent.toolkit.message_notifier import MessageNotifier
 
 DEFAULT_SYSTEM_PROMPT = '''
 <基本设定>
-你是一位熟悉中老年用户交流习惯的微信客服Agent
+你是一位熟悉中老年用户交流习惯的微信客服。
 你擅长以下事项:
-* 倾听、引导和共情,在对话中自然促进用户互动
 * 理解中老年人的典型情感需求、对话习惯
-* 分析用户的微信名、头像,以适合的话术与用户建立联系
+* 倾听、引导和共情,在对话中自然促进用户互动
 
 你的工作方法论:
 * 分析用户请求以确定核心需求
@@ -26,11 +26,44 @@ DEFAULT_SYSTEM_PROMPT = '''
 * 所有的思考和回答都要用工作语言
 </语言设定>
 
-<通用话术>
-* 时间锚点:"早上好!今天阳光这么好,您打算做点什么让自己开心的事呀?"
-* 轻量求助:"听说最近好多长辈在学手机拍照技巧,您有没有什么实用小窍门能教教我呀?"
-* 正向引导:"这个季节最适合喝养生茶啦,您平时喜欢枸杞红枣茶还是菊花茶呀?"
-</通用话术>
+<社交阶段划分>
+* 破冰试探期
+* 角色探索期
+* 情感联结期
+</社交阶段划分>
+
+<通用话术列表>
+<话术>
+* 简介:简单自我介绍,并向用户发出交友邀请
+* 适用情形:破冰试探
+* 内容:根据你的人设做出简短自我介绍,并向用户发出交友邀请
+* 要求:简短亲切,1-2句话;模拟微信聊天的回复格式,分段清晰
+</话术>
+<话术>
+* 简介:使用用户微信名做藏头诗,进行问候
+* 适用情形:破冰试探
+* 内容:首先进行用户问候,然后介绍藏头诗,最后以一个让老年人有动力继续聊天的问句结尾
+* 要求:最后的问句要与藏头诗或用户自身经历有关,与藏头诗自然承接,无需和用户画像其他内容相关
+</话术>
+<话术>
+* 简介:介绍自己的兴趣并询问用户的兴趣爱好
+* 适用情形:与用户有一定交流基础之后
+* 内容:根据用户头像分析用户的特点、可能的兴趣爱好,作为参考,表述自己有相同的兴趣爱好,并举一些简短的例子 ,然后询问用户的兴趣爱好
+* 要求:询问无需和用户画像中其他信息有关;说明引发你问候的是用户头像;简短亲切,2-3句话 30字左右;如无用户信息或行为,不要根据联想杜撰用户偏好/行为
+</话术>
+<话术>
+* 简介:对用户进行节日/节气相关问候
+* 适用情形:不限
+* 内容:结合具体节假日及其习俗产生问候,以一个让老年人有动力继续聊天的问句结尾,与前面的问候自然承接
+* 要求:根据今日或近日实际日期,不要假设日期和节日;忽略小众节假日,和根据最近的节假日产生问候,如临近或刚过完重要节日,可询问节日安排或节日经历;简短亲切,2-3句话 30字左右;如无用户信息或行为,不要根据联想杜撰用户偏好/行为
+</话术>
+<话术>
+* 简介:询问用户当日计划安排并产生问候
+* 适用情形:与用户有一定交流基础之后
+* 内容:向用户介绍你的今日安排以及询问用户的今日安排
+* 要求:简短亲切,1-2句话,像用户熟悉的晚辈一样问候沟通;模拟微信聊天的回复格式,分段清晰
+</话术>
+</通用话术列表>
 
 <心理学技巧>
 * 怀旧效应:可以用"当年/以前"触发美好回忆
@@ -40,7 +73,7 @@ DEFAULT_SYSTEM_PROMPT = '''
 
 <风险规避原则>
 * 避免过度打扰和重复:注意分析历史对话
-* 避免过度解读
+* 避免过度解读:不要过度解读用户的信息
 * 文化适配:注意不同地域的用户文化差异
 * 准确性要求:不要使用虚构的信息
 </风险规避原则>
@@ -56,9 +89,11 @@ You are operating in an agent loop, iteratively completing tasks through these s
 </agent_loop>
 '''
 
-QUERY_PROMPT_TEMPLATE = """现在,请通过多步思考,选择合适的方法向一位用户发起问候。
-# 已知用户的信息
-用户信息:
+QUERY_PROMPT_TEMPLATE = """现在,请通过多步思考,以客服的角色,选择合适的方法向一位用户发起问候。
+# 客服的基本信息
+{formatted_staff_profile}
+# 用户的信息
+- 微信昵称:{nickname}
 - 姓名:{name}
 - 头像:{avatar}
 - 偏好的称呼:{preferred_nickname}
@@ -74,10 +109,12 @@ QUERY_PROMPT_TEMPLATE = """现在,请通过多步思考,选择合适的方
 时间:{current_datetime}
 
 注意对话信息的格式为: [角色][时间]对话内容
+注意分析客服和用户当前的社交阶段,先确立本次问候的目的。
 注意一定要分析对话信息中的时间,避免和当前时间段不符的内容!注意一定要结合历史的对话情况进行分析和问候方式的选择!
-可以使用analyse_image分析用户头像。
+如有必要,可以使用analyse_image分析用户头像。
 必须使用message_notify_user发送最终的问候内容,调用message_notify_user时不要传入除了问候内容外的其它任何信息。
-Please think step by step.
+注意每次问候只使用一种话术。
+Now, start to process your task. Please think step by step.
  """
 
 class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
@@ -95,14 +132,39 @@ class MessagePushAgent(SimpleOpenAICompatibleChatAgent):
         ])
         super().__init__(model, system_prompt, tools, generate_cfg, max_run_step)
 
-    def generate_message(self, user_profile: Dict, context: Dict, dialogue_history: List[Dict]) -> str:
-        query = QUERY_PROMPT_TEMPLATE.format(**user_profile, **context, dialogue_history=dialogue_history)
+    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> str:
+        formatted_dialogue = MessagePushAgent.compose_dialogue(dialogue_history)
+        query = QUERY_PROMPT_TEMPLATE.format(**context, dialogue_history=formatted_dialogue)
         self.run(query)
         for tool_call in reversed(self.tool_call_records):
             if tool_call['name'] == MessageNotifier.message_notify_user.__name__:
                 return tool_call['arguments']['message']
         return ''
 
+    @staticmethod
+    def compose_dialogue(dialogue: List[Dict]) -> str:  # Render dialogue entries as "[role][time]content" lines joined by newlines.
+        role_map = {'user': '用户', 'assistant': '客服'}  # role key -> label written into the prompt
+        messages = []
+        for msg in dialogue:
+            if not msg['content']:  # skip entries whose content is empty or otherwise falsy
+                continue
+            if msg['role'] not in role_map:  # skip any role other than 'user' / 'assistant'
+                continue
+            format_dt = datetime.datetime.fromtimestamp(msg['timestamp'] / 1000).strftime('%Y-%m-%d %H:%M:%S')  # timestamp is divided by 1000, i.e. treated as epoch milliseconds; no tz is passed, so the result is local time
+            messages.append('[{}][{}]{}'.format(role_map[msg['role']], format_dt, msg['content']))
+        return '\n'.join(messages)  # empty string when no entry survives the filters
+
+class DummyMessagePushAgent(MessagePushAgent):  # Overrides generate_message to return a canned string built from the context.
+    """A dummy agent for testing purposes."""
+
+    def __init__(self, *args, **kwargs):  # pass-through: forwards every argument to MessagePushAgent.__init__
+        super().__init__(*args, **kwargs)
+
+    def generate_message(self, context: Dict, dialogue_history: List[Dict]) -> str:  # dialogue_history is accepted but not used
+        logger.debug(f"DummyMessagePushAgent.generate_message called, context: {context}")  # NOTE(review): the import of `logger` is not visible in this hunk - confirm it is in scope in this module
+        return "测试消息: {agent_name} -> {nickname}".format(**context)  # context must contain 'agent_name' and 'nickname'; a missing key raises KeyError
+
+
 if __name__ == '__main__':
     import pqai_agent.logging_service
     pqai_agent.logging_service.setup_root_logger()
@@ -118,9 +180,15 @@ if __name__ == '__main__':
         'interests': ['钓鱼', '旅游']
     }
     test_context = {
-        "current_datetime": "2025-05-13 08:00:00",
+        "current_datetime": "2025-05-12 08:00:00",
+        **test_user_profile
     }
-    response = agent.generate_message(test_user_profile, test_context, [])
+    def create_ts(year, month, day, hour, minute):
+        return datetime.datetime(year, month, day, hour, minute).timestamp() * 1000
+    messages = [
+        {"role": "assistant", "content": "月哥,早上好!看到您的头像是一片宁静的户外风景,感觉您一定很喜欢大自然吧?今天天气不错,您有什么计划吗?", "timestamp": create_ts(2025, 5, 10, 8, 0)},
+    ]
+    response = agent.generate_message(test_context, messages)
     print(response)
 
 

+ 4 - 2
pqai_agent/configs/dev.yaml

@@ -36,7 +36,7 @@ storage:
 agent_behavior:
   message_aggregation_sec: 3
   active_conversation_schedule_param:
-    minute: 24,54
+    second: 24,54
 
 chat_api:
   coze:
@@ -71,4 +71,6 @@ mq:
   receive_group: qywx_receive_msg
   send_topic: qywx_send_msg
   scheduler_topic: agent_scheduler_event
-  scheduler_group: agent_scheduler_event
+  scheduler_group: agent_scheduler_event
+  push_tasks_topic: agent_push_tasks_dev
+  push_tasks_group: agent_push_tasks_dev

+ 3 - 1
pqai_agent/configs/prod.yaml

@@ -70,4 +70,6 @@ mq:
   receive_group: qywx_receive_msg
   send_topic: qywx_send_msg
   scheduler_topic: agent_scheduler_event
-  scheduler_group: agent_scheduler_event
+  scheduler_group: agent_scheduler_event
+  push_tasks_topic: agent_push_tasks
+  push_tasks_group: agent_push_tasks

+ 40 - 36
pqai_agent/dialogue_manager.py

@@ -23,6 +23,8 @@ from pqai_agent.message import MessageType, Message
 from pqai_agent.toolkit.lark_alert_for_human_intervention import LarkAlertForHumanIntervention
 from pqai_agent.toolkit.lark_sheet_record_for_human_intervention import LarkSheetRecordForHumanIntervention
 from pqai_agent.user_manager import UserManager
+from pqai_agent.utils import prompt_utils
+
 
 class DummyVectorMemoryManager:
     def __init__(self, user_id):
@@ -142,15 +144,20 @@ class DialogueManager:
             return TimeContext.NIGHT
 
     def is_valid(self):
-        if not self.staff_profile.get('agent_name', None):
+        if not self.staff_profile.get('name', None) and not self.staff_profile.get('agent_name', None):
             return False
         return True
 
+    def refresh_profile(self):  # Re-fetch the staff profile for self.staff_id from user_manager and replace self.staff_profile.
+        self.staff_profile = self.user_manager.get_staff_profile(self.staff_id)
+
     def _recover_state(self):
         self.current_state, self.previous_state = self.state_cache.get_state(self.staff_id, self.user_id)
 
         # 从数据库恢复对话状态
-        self.dialogue_history = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id)
+        minutes_to_get = 5 * 24 * 60
+        self.dialogue_history = self.history_dialogue_service.get_dialogue_history(
+            self.staff_id, self.user_id, minutes_to_get)
         if self.dialogue_history:
             self.last_interaction_time = self.dialogue_history[-1]['timestamp']
             if self.current_state == DialogueState.MESSAGE_AGGREGATING:
@@ -160,8 +167,8 @@ class DialogueManager:
                         self.unprocessed_messages.append(entry['content'])
                         break
         else:
-            # 默认设置为24小时前
-            self.last_interaction_time = int(time.time() * 1000) - 24 * 3600 * 1000
+            # 默认设置
+            self.last_interaction_time = int(time.time() * 1000) - minutes_to_get * 60 * 1000
         time_for_read = datetime.fromtimestamp(self.last_interaction_time / 1000).strftime("%Y-%m-%d %H:%M:%S")
         logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
 
@@ -323,36 +330,16 @@ class DialogueManager:
         # 默认为闲聊状态
         return DialogueState.CHITCHAT
 
-    def _trigger_human_intervention(self, reason: str) -> None:
-        """触发人工介入"""
-        # 记录人工介入事件
-        # FIXME: 重启即丢失
-        event = {
-            "timestamp": int(time.time() * 1000),
-            "reason": reason,
-            "dialogue_context": self.dialogue_history[-10:]
-        }
-
-        # 更新用户资料中的人工介入历史
-        if "human_intervention_history" not in self.user_profile:
-            self.user_profile["human_intervention_history"] = []
-
-        self.user_profile["human_intervention_history"].append(event)
-        self.user_manager.save_user_profile(self.user_id, self.user_profile)
-
-        # 发送告警
-        self._send_human_intervention_alert()
-
-    def _send_human_intervention_alert(self, reason: Optional[str] = None) -> None:
+    def _send_alert(self, alert_type: str, reason: Optional[str] = None) -> None:
         time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-        staff_info = f"{self.staff_profile.get('agent_name', '未知')}[{self.staff_id}]"
+        staff_info = f"{self.staff_profile.get('name', '未知')}[{self.staff_id}]"
         user_info = f"{self.user_profile.get('nickname', '未知')}[{self.user_id}]"
         alert_message = f"""
-        人工介入告警
+        {alert_type}告警
         员工: {staff_info}
         用户: {user_info}
         时间: {time_str}
-        原因:{reason if reason else '未知'}
+        原因:{reason if reason else "未知"}
         最近对话:
         """
 
@@ -370,13 +357,17 @@ class DialogueManager:
             dialogue_to_send.append(f"[{role_map[role]}]{dialogue['content']}")
         alert_message += '\n'.join(dialogue_to_send)
 
-        ack_url = "http://ai-wechat-hook.piaoquantv.com/manage/insertEvent?" \
-                f"sender={self.user_id}&receiver={self.staff_id}&type={MessageType.HUMAN_INTERVENTION_END.value}&content=OPERATION"
+        if alert_type == '人工介入':
+            ack_url = "http://ai-wechat-hook.piaoquantv.com/manage/insertEvent?" \
+                      f"sender={self.user_id}&receiver={self.staff_id}&type={MessageType.HUMAN_INTERVENTION_END.value}&content=OPERATION"
+        else:
+            ack_url = None
 
         LarkAlertForHumanIntervention().send_lark_alert_for_human_intervention(alert_message, ack_url)
-        LarkSheetRecordForHumanIntervention().send_lark_sheet_record_for_human_intervention(
-            staff_info, user_info, '\n'.join(dialogue_to_send), reason
-        )
+        if alert_type == '人工介入':
+            LarkSheetRecordForHumanIntervention().send_lark_sheet_record_for_human_intervention(
+                staff_info, user_info, '\n'.join(dialogue_to_send), reason
+            )
 
     def resume_from_human_intervention(self) -> None:
         """从人工介入状态恢复"""
@@ -394,7 +385,14 @@ class DialogueManager:
             reason = llm_response.replace('<人工介入>', '')
             logger.warning(f'staff[{self.staff_id}], user[{self.user_id}]: human intervention triggered, reason: {reason}')
             self.do_state_change(DialogueState.HUMAN_INTERVENTION)
-            self._send_human_intervention_alert(reason)
+            self._send_alert('人工介入', reason)
+            return None
+
+        if '<结束>' in llm_response or '<负向情绪结束>' in llm_response:
+            logger.warning(f'staff[{self.staff_id}], user[{self.user_id}]: conversation ended')
+            self.do_state_change(DialogueState.FAREWELL)
+            if '<负向情绪结束>' in llm_response:
+                self._send_alert("用户负向情绪")
             return None
 
         """根据当前状态处理LLM响应,如果处于人工介入状态则返回None"""
@@ -454,6 +452,8 @@ class DialogueManager:
 
     @staticmethod
     def is_time_suitable_for_active_conversation(time_context=None) -> bool:
+        if configs.get_env() == 'dev':
+            return True
         if not time_context:
             time_context = DialogueManager.get_time_context()
         if time_context in [TimeContext.MORNING,
@@ -473,10 +473,13 @@ class DialogueManager:
         self.user_profile = self.user_manager.get_user_profile(self.user_id)
         # 刷新员工画像(不一定需要)
         self.staff_profile = self.user_manager.get_staff_profile(self.staff_id)
+        # 员工画像添加前缀,避免冲突,实现Coze Prompt模板的平滑升级
+        legacy_staff_profile = {}
+        for key in self.staff_profile:
+            legacy_staff_profile[f'agent_{key}'] = self.staff_profile[key]
 
         current_datetime = datetime.now()
         context = {
-            "user_profile": self.user_profile,
             "current_state": self.current_state.name,
             "previous_state": self.previous_state.name,
             "current_time_period": time_context.description,
@@ -487,8 +490,9 @@ class DialogueManager:
             "last_interaction_interval": self._get_hours_since_last_interaction(2),
             "if_first_interaction": True if self.previous_state == DialogueState.INITIALIZED else False,
             "if_active_greeting": False if user_message else True,
+            "formatted_staff_profile": prompt_utils.format_agent_profile(self.staff_profile),
             **self.user_profile,
-            **self.staff_profile
+            **legacy_staff_profile
         }
 
         # 获取长期记忆

+ 3 - 0
pqai_agent/exceptions.py

@@ -0,0 +1,3 @@
+class NoRetryException(Exception):  # Caught in AgentService.process_messages, which logs the error and still acks the message ("skip retry").
+    def __init__(self, *args, **kwargs):  # pass-through constructor; adds no behavior beyond Exception.__init__
+        super().__init__(*args, **kwargs)

+ 50 - 14
pqai_agent/history_dialogue_service.py

@@ -1,8 +1,12 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
+from typing import Dict, List
 
 import requests
+from pymysql.cursors import DictCursor
+
+from pqai_agent.database import MySQLManager
 from pqai_agent.logging_service import logger
 import time
 
@@ -14,19 +18,11 @@ class HistoryDialogueService:
     def __init__(self, base_url: str):
         self.base_url = base_url
 
-    def get_dialogue_history(self, staff_id: str, user_id: str, recent_minutes: int = 1440):
-        time_begin = int(time.time() * 1000) - recent_minutes * 60 * 1000
-        url = f"{self.base_url}?sender={staff_id}&receiver={user_id}&time={time_begin}"
-        response = requests.post(url, headers={
-            'Content-Type': 'application/json'
-        })
-        if response.status_code != 200:
-            raise Exception("Request error [{}]: {}".format(response.status_code, response.text))
-        data = response.json()
-        if not data.get('success', False):
-            raise Exception("Error in response: {}".format(data.get('message', 'no message returned')))
-        data = data.get('data', [])
+    @staticmethod
+    def convert_raw_records_to_base_messages(data: List[Dict], staff_id: str, user_id: str, reverse: bool = False) -> List[Dict]:
         ret = []
+        if reverse:
+            data = reversed(data)
         for record in data:
             sender = record.get('sender')
             if sender == user_id:
@@ -47,12 +43,52 @@ class HistoryDialogueService:
                 logger.warning(f"staff[{staff_id}], user[{user_id}]: skip unsupported message type {message['type']}")
                 continue
             ret.append(message)
+        return ret
+
+    def get_dialogue_history(self, staff_id: str, user_id: str, recent_minutes: int = 1440):
+        time_begin = int(time.time() * 1000) - recent_minutes * 60 * 1000
+        url = f"{self.base_url}?sender={staff_id}&receiver={user_id}&time={time_begin}"
+        response = requests.post(url, headers={
+            'Content-Type': 'application/json'
+        })
+        if response.status_code != 200:
+            raise Exception("Request error [{}]: {}".format(response.status_code, response.text))
+        data = response.json()
+        if not data.get('success', False):
+            raise Exception("Error in response: {}".format(data.get('message', 'no message returned')))
+        data = data.get('data', [])
+        ret = self.convert_raw_records_to_base_messages(data, staff_id, user_id)
         ret = sorted(ret, key=lambda x: x['timestamp'])
         return ret
 
 
+class HistoryDialogueDatabase:
+    """Reads private-chat history rows for a staff/user pair from a MySQL table."""
+    PRIVATE_ROOM_ID_FORMAT = 'private:%s:%s'
+
+    def __init__(self, db_config, table_name: str = 'qywx_chat_history'):
+        """Create the MySQL manager from ``db_config`` and remember the table name."""
+        self.table_name = table_name
+        self.db = MySQLManager(db_config)
+
+    def get_dialogue_history_backward(self, staff_id: str, user_id: str, end_timestamp_ms: int, limit: int = 100):
+        """Fetch up to ``limit`` messages sent at or before ``end_timestamp_ms``.
+
+        The room id is built from the two ids in ascending string order. Rows are
+        queried newest-first and handed to the converter with ``reverse=True``.
+
+        :return: converted message dicts, or ``[]`` when the query yields nothing.
+        """
+        if staff_id < user_id:
+            id_pair = (staff_id, user_id)
+        else:
+            id_pair = (user_id, staff_id)
+        room_id = self.PRIVATE_ROOM_ID_FORMAT % id_pair
+        query = (
+            f"SELECT sender, receiver, msg_type, content, sendtime as sendTime FROM {self.table_name} "
+            "WHERE roomid = %s AND sendtime <= %s ORDER BY sendtime DESC LIMIT %s"
+        )
+        rows = self.db.select(query, DictCursor, (room_id, end_timestamp_ms, limit))
+        if rows:
+            return HistoryDialogueService.convert_raw_records_to_base_messages(
+                rows, staff_id, user_id, reverse=True)
+        return []
+
+
+
 if __name__ == '__main__':
     api_url = configs.get()['storage']['history_dialogue']['api_base_url']
     service = HistoryDialogueService(api_url)
-    resp = service.get_dialogue_history(staff_id='1688854492669990', user_id='7881301263964433')
-    print(resp)
+    resp = service.get_dialogue_history(staff_id='1688857241615085', user_id='7881299616070168', recent_minutes=5*1440)
+    print(resp)
+    user_db_config = configs.get()['storage']['user']['mysql']
+    db = HistoryDialogueDatabase(user_db_config)
+    # print(db.get_dialogue_history_backward('1688854492669990', '7881301263964433', 1747397155000))

+ 3 - 2
pqai_agent/message_queue_backend.py

@@ -55,7 +55,8 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
                  has_consumer: bool = False, has_producer: bool = False,
                  group_id: Optional[str] = None,
                  ak:Optional[str] = None, sk: Optional[str] = None,
-                 topic_type: Optional[str] = None):
+                 topic_type: Optional[str] = None,
+                 await_duration: int = 20):
         if not has_consumer and not has_producer:
             raise ValueError("At least one of has_consumer or has_producer must be True.")
         self.has_consumer = has_consumer
@@ -66,7 +67,7 @@ class AliyunRocketMQQueueBackend(MessageQueueBackend):
         self.topic = topic
         self.group_id = group_id
         if has_consumer:
-            self.consumer = SimpleConsumer(mq_config, group_id)
+            self.consumer = SimpleConsumer(mq_config, group_id, await_duration=await_duration)
             self.consumer.startup()
             self.consumer.subscribe(self.topic)
         if has_producer:

+ 184 - 0
pqai_agent/push_service.py

@@ -0,0 +1,184 @@
+import json
+import time
+import traceback
+import uuid
+from datetime import datetime
+from enum import Enum
+from concurrent.futures import ThreadPoolExecutor
+from threading import Thread
+from typing import Optional, Dict
+
+import rocketmq
+from rocketmq import ClientConfiguration, Credentials, SimpleConsumer, FilterExpression
+
+from pqai_agent import configs
+from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessagePushAgent
+from pqai_agent.configs import apollo_config
+from pqai_agent.logging_service import logger
+from pqai_agent.message import MessageType
+
+
+class TaskType(Enum):
+    """Kinds of push task; the value is used as both the task_type field and the message tag."""
+    GENERATE = "generate"
+    SEND = "send"
+
+def generate_task_rmq_message(topic: str, staff_id: str, user_id: str, task_type: TaskType, content: Optional[str] = None) -> rocketmq.Message:
+    """Build the RocketMQ message that describes one push task.
+
+    The body is a UTF-8 encoded JSON object holding the staff id, user id, task
+    type value, content (empty string when none is given) and the current time
+    in milliseconds. The message tag is the task type value.
+    """
+    payload = {
+        'staff_id': staff_id,
+        'user_id': user_id,
+        'task_type': task_type.value,
+        # FIXME: multimodal message content needs to be supported
+        'content': content or '',
+        'timestamp': int(time.time() * 1000),
+    }
+    message = rocketmq.Message()
+    message.topic = topic
+    message.body = json.dumps(payload, ensure_ascii=False).encode('utf-8')
+    message.tag = task_type.value
+    return message
+
+class PushScanThread:
+    """Scans one staff member's users and enqueues GENERATE tasks for those to be contacted."""
+    # PushScanThread could simply be a method of AgentService; it is kept separate mainly
+    # because push may later be split out and extended.
+    def __init__(self, staff_id: str, agent_service: 'AgentService', mq_topic: str, mq_producer: rocketmq.Producer):
+        self.staff_id = staff_id
+        # Relies heavily on AgentService's internal members
+        self.service = agent_service
+        self.rmq_topic = mq_topic
+        self.rmq_producer = mq_producer
+
+    def run(self):
+        """Send a GENERATE task for each user whose agent wants to initiate a conversation.
+
+        Outside the 'dev' environment a user must also carry at least one
+        whitelisted tag.
+        """
+        # Merge the two whitelists to reduce configuration cost
+        allowed_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags', []))
+        allowed_tags |= set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
+        relation_manager = self.service.user_relation_manager
+        for relation in relation_manager.list_staff_users(staff_id=self.staff_id):
+            staff_id = relation['staff_id']
+            user_id = relation['user_id']
+            agent = self.service.get_agent_instance(staff_id, user_id)
+            wants_to_initiate = agent.should_initiate_conversation()
+            user_tags = relation_manager.get_user_tags(user_id)
+
+            # The tag check is skipped entirely in the 'dev' environment
+            tag_allowed = configs.get_env() == 'dev' or bool(allowed_tags.intersection(user_tags))
+            if not tag_allowed or not wants_to_initiate:
+                logger.debug(f"user[{user_id}], do not initiate conversation")
+                continue
+            logger.info(f"user[{user_id}], tags{user_tags}: generate a generation task for conversation initiation")
+            task_message = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.GENERATE)
+            self.rmq_producer.send(task_message)
+
+
+class PushTaskWorkerPool:
+    """Consumes push tasks from RocketMQ and dispatches them to worker threads.
+
+    GENERATE tasks run on one shared thread pool. SEND tasks run on a separate
+    single-worker pool per staff id, so sends of one staff execute one at a time.
+    """
+    def __init__(self, agent_service: 'AgentService', mq_topic: str,
+                 mq_consumer: rocketmq.SimpleConsumer, mq_producer: rocketmq.Producer):
+        self.agent_service = agent_service
+        max_workers = configs.get()['system'].get('push_task_workers', 5)
+        self.generate_executor = ThreadPoolExecutor(max_workers=max_workers)
+        # staff_id -> single-worker executor, created lazily on the first SEND task
+        self.send_executors = {}
+        self.rmq_topic = mq_topic
+        self.consumer = mq_consumer
+        self.producer = mq_producer
+        self.loop_thread = None
+        self.is_generator_running = True
+        self.generate_send_done = False # set by wait_to_finish
+        self.no_more_generate_task = False # set by self
+
+    def start(self):
+        """Start the background thread that runs ``process_push_tasks``."""
+        self.loop_thread = Thread(target=self.process_push_tasks)
+        self.loop_thread.start()
+
+    def process_push_tasks(self):
+        """Receive task messages in a loop and dispatch each to the matching executor.
+
+        The loop ends only on an empty receive after ``wait_to_finish`` has set
+        ``generate_send_done``, this loop has set ``no_more_generate_task``, and
+        ``is_generator_running`` has been cleared.
+
+        A message whose body cannot be decoded into a task is logged, acked and
+        skipped, so one bad message cannot kill this thread (which would leave
+        ``wait_to_finish`` waiting forever).
+        """
+        # The RMQ consumer seems to have a bug: consuming right after creation may raise an NPE
+        time.sleep(1)
+        while True:
+            msgs = self.consumer.receive(1, 300)
+            if not msgs:
+                # Exit only when no generate task is running and no message is left
+                if self.generate_send_done:
+                    if not self.no_more_generate_task:
+                        logger.debug("no message received, there should be no more generate task")
+                        self.no_more_generate_task = True
+                        continue
+                    else:
+                        if self.is_generator_running:
+                            logger.debug("Waiting for generator threads to finish")
+                            continue
+                        else:
+                            break
+                else:
+                    continue
+            msg = msgs[0]
+            try:
+                task = json.loads(msg.body.decode('utf-8'))
+                task_type = task['task_type']
+                msg_time = datetime.fromtimestamp(task['timestamp'] / 1000).strftime("%Y-%m-%d %H:%M:%S")
+                # SEND tasks are routed by staff id, so it must be present before dispatch
+                send_staff_id = task['staff_id'] if task_type == TaskType.SEND.value else None
+            except Exception as e:
+                fmt_exc = traceback.format_exc()
+                logger.error(f"Invalid push task message dropped: {e}, {fmt_exc}")
+                self.consumer.ack(msg)
+                continue
+            logger.debug(f"recv message:{msg_time} - {task}")
+            if task_type == TaskType.GENERATE.value:
+                self.generate_executor.submit(self.handle_generate_task, task, msg)
+            elif task_type == TaskType.SEND.value:
+                if send_staff_id not in self.send_executors:
+                    self.send_executors[send_staff_id] = ThreadPoolExecutor(max_workers=1)
+                self.send_executors[send_staff_id].submit(self.handle_send_task, task, msg)
+            else:
+                logger.error(f"Unknown task type: {task_type}")
+                self.consumer.ack(msg)
+        logger.info("PushGenerateWorkerPool stopped")
+
+    def wait_to_finish(self):
+        """Signal that no more tasks will be produced and block until the loop thread exits.
+
+        Waits for the loop to observe an empty queue, shuts down the generate
+        executor (waiting for running tasks), then joins the loop thread.
+        """
+        self.generate_send_done = True
+        while not self.no_more_generate_task:
+            #FIXME(zhoutian): condition variable should be used to replace time sleep
+            time.sleep(1)
+        self.generate_executor.shutdown(wait=True)
+        self.is_generator_running = False
+        self.loop_thread.join()
+
+    def handle_send_task(self, task: Dict, msg: rocketmq.Message):
+        """Handle one SEND task: re-check, build the response, send it, then ack.
+
+        The message is acked on every path, including failures, which are logged.
+        """
+        try:
+            staff_id = task['staff_id']
+            user_id = task['user_id']
+            agent = self.agent_service.get_agent_instance(staff_id, user_id)
+            # Re-check whether sending is still needed
+            if not agent.should_initiate_conversation():
+                logger.debug(f"user[{user_id}], do not initiate conversation")
+                self.consumer.ack(msg)
+                return
+            content = task['content']
+            recent_dialogue = agent.dialogue_history[-10:]
+            agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
+            # FIXME(zhoutian): this should no longer be controlled by the agent, or the agent
+            # and the API should share the same configuration
+            if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
+                message_type = MessageType.TEXT
+            else:
+                message_type = self.agent_service.response_type_detector.detect_type(
+                    recent_dialogue, content, enable_random=True)
+            response = agent.generate_response(content)
+            if response:
+                self.agent_service.send_response(staff_id, user_id, response, message_type, skip_check=True)
+            else:
+                logger.debug(f"agent[{staff_id}] generate empty response")
+            self.consumer.ack(msg)
+        except Exception as e:
+            fmt_exc = traceback.format_exc()
+            logger.error(f"Error processing message sending: {e}, {fmt_exc}")
+            self.consumer.ack(msg)
+
+    def handle_generate_task(self, task: Dict, msg: rocketmq.Message):
+        """Handle one GENERATE task: produce the push text and enqueue a SEND task for it.
+
+        The message is acked on every path, including failures, which are logged.
+        """
+        try:
+            staff_id = task['staff_id']
+            user_id = task['user_id']
+            main_agent = self.agent_service.get_agent_instance(staff_id, user_id)
+            push_agent = MessagePushAgent()
+            message_to_user = push_agent.generate_message(
+                context=main_agent.get_prompt_context(None),
+                dialogue_history=self.agent_service.history_dialogue_db.get_dialogue_history_backward(
+                    staff_id, user_id, main_agent.last_interaction_time, limit=100
+                )
+            )
+            rmq_message = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.SEND, message_to_user)
+            logger.debug(f"send message: {rmq_message.body.decode('utf-8')}")
+            self.producer.send(rmq_message)
+            self.consumer.ack(msg)
+        except Exception as e:
+            fmt_exc = traceback.format_exc()
+            logger.error(f"Error processing message generation: {e}, {fmt_exc}")
+            # FIXME: should this message be ACKed?
+            self.consumer.ack(msg)

+ 2 - 0
pqai_agent/response_type_detector.py

@@ -75,6 +75,8 @@ class ResponseTypeDetector:
 
     @staticmethod
     def if_message_suitable_for_voice(message):
+        if not message:
+            return False
         # 使用语音的文字不适合过长
         if len(message) > 50:
             return False

+ 67 - 12
pqai_agent/user_manager.py

@@ -77,7 +77,7 @@ class UserRelationManager(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def list_staff_users(self) -> List[Dict]:
+    def list_staff_users(self, staff_id: str = None, tag_id: int = None) -> List[Dict]:
         pass
 
     @abc.abstractmethod
@@ -124,13 +124,16 @@ class LocalUserManager(UserManager):
         return user_ids
 
     def get_staff_profile(self, staff_id) -> Dict:
-        # for test only
-        return {
-            'agent_name': '小芳',
-            'agent_gender': '女',
-            'agent_age': 30,
-            'agent_region': '北京'
-        }
+        try:
+            with open(f"user_profiles/{staff_id}.json", "r", encoding="utf-8") as f:
+                profile = json.load(f)
+            entry_added = False
+            if entry_added:
+                self.save_user_profile(staff_id, profile)
+            return profile
+        except Exception as e:
+            logger.error("staff profile not found: {}".format(e))
+            return {}
 
     def list_users(self, **kwargs) -> List[Dict]:
         pass
@@ -190,7 +193,10 @@ class MySQLUserManager(UserManager):
     def get_staff_profile(self, staff_id) -> Dict:
         if not self.staff_table:
             raise Exception("staff_table is not set")
-        sql = f"SELECT agent_name, agent_gender, agent_age, agent_region " \
+        return self.get_staff_profile_v3(staff_id)
+
+    def get_staff_profile_v1(self, staff_id) -> Dict:
+        sql = f"SELECT agent_name, agent_gender, agent_age, agent_region, agent_profile " \
               f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
         data = self.db.select(sql, pymysql.cursors.DictCursor)
         if not data:
@@ -202,6 +208,49 @@ class MySQLUserManager(UserManager):
         profile['agent_gender'] = gender_map[profile['agent_gender']]
         return profile
 
+    def get_staff_profile_v2(self, staff_id) -> Dict:
+        """Load a staff profile from the flat columns merged with the JSON ``agent_profile`` column.
+
+        Keys from the JSON column overwrite the flat columns on conflict, and the
+        raw ``agent_profile`` entry is removed from the result.
+
+        :return: the merged profile, or ``{}`` when no row matches ``staff_id``.
+        """
+        # NOTE(review): staff_id is interpolated into the SQL text; bind it as a
+        # parameter if it can ever come from untrusted input.
+        sql = f"SELECT agent_name as name, agent_gender as gender, agent_age as age, agent_region as region, agent_profile " \
+              f"FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
+        data = self.db.select(sql, pymysql.cursors.DictCursor)
+        if not data:
+            logger.error(f"staff[{staff_id}] not found")
+            return {}
+        profile = data[0]
+        # Convert the gender code to its display text; unexpected codes fall back
+        # to '未知' instead of raising KeyError
+        gender_map = {0: '未知', 1: '男', 2: '女', None: '未知'}
+        profile['gender'] = gender_map.get(profile['gender'], '未知')
+
+        # Merge the data of the JSON field (new version)
+        if profile['agent_profile']:
+            detail_profile = json.loads(profile['agent_profile'])
+            profile.update(detail_profile)
+
+        # Drop the raw field
+        profile.pop('agent_profile', None)
+        return profile
+
+    def get_staff_profile_v3(self, staff_id) -> Dict:
+        """Load a staff profile from the JSON ``agent_profile`` column only.
+
+        :return: the decoded JSON value, or ``{}`` when no row matches (an error
+            is logged) or the column is empty.
+        """
+        query = f"SELECT agent_profile FROM {self.staff_table} WHERE third_party_user_id = '{staff_id}'"
+        rows = self.db.select(query)
+        if not rows:
+            logger.error(f"staff[{staff_id}] not found")
+            return {}
+        raw_profile = rows[0][0]
+        return json.loads(raw_profile) if raw_profile else {}
+
+    def save_staff_profile(self, staff_id: str, profile: Dict):
+        """Overwrite the JSON ``agent_profile`` column of the given staff row.
+
+        Both the serialized profile and ``staff_id`` are passed as bound query
+        parameters rather than being formatted into the SQL text.
+
+        :raises Exception: when ``staff_table`` is not set or ``staff_id`` is empty.
+        """
+        # This operation should not be needed under normal circumstances
+        if not self.staff_table:
+            raise Exception("staff_table is not set")
+        if not staff_id:
+            raise Exception("Invalid staff_id: {}".format(staff_id))
+        sql = f"UPDATE {self.staff_table} SET agent_profile = %s WHERE third_party_user_id = %s"
+        self.db.execute(sql, (json.dumps(profile), staff_id))
+
     def list_users(self, **kwargs) -> List[Dict]:
         user_union_id = kwargs.get('user_union_id', None)
         user_name = kwargs.get('user_name', None)
@@ -221,15 +270,21 @@ class LocalUserRelationManager(UserRelationManager):
 
     def list_staffs(self):
         return [
-            {"third_party_user_id": 0, "name": "x", "wxid": "x", "agent_name": "小芳"}
+            {"third_party_user_id": '1688855931724582', "name": "", "wxid": "ShengHuoLeQu", "agent_name": "小芳"}
         ]
 
     def list_users(self, staff_id: str, page: int = 1, page_size: int = 100):
         return []
 
-    def list_staff_users(self):
+    def list_staff_users(self, staff_id: str = None, tag_id: int = None):
+        user_ids = ['7881299453089278', '7881299453132630', '7881299454186909', '7881299455103430', '7881299455173476',
+                    '7881299456216398', '7881299457990953', '7881299461167644', '7881299463002136', '7881299464081604',
+                    '7881299465121735', '7881299465998082', '7881299466221881', '7881299467152300', '7881299470051791',
+                    '7881299470112816', '7881299471149567', '7881299471168030', '7881299471277650', '7881299473321703']
+        user_ids = user_ids[:5]
         return [
-            {"staff_id": "1688854492669990", "user_id": "7881299670930896"}
+            {"staff_id": "1688855931724582", "user_id": "7881299670930896"},
+            *[{"staff_id": "1688855931724582", "user_id": user_id} for user_id in user_ids]
         ]
 
     def get_user_tags(self, user_id: str):

+ 0 - 0
pqai_agent/utils/__init__.py


+ 23 - 0
pqai_agent/utils/prompt_utils.py

@@ -0,0 +1,23 @@
+from typing import Dict
+
+
+def format_agent_profile(profile: Dict) -> str:
+    """Render the non-empty known fields of an agent profile as a bulleted list.
+
+    Fields appear in a fixed order, one ``- <label>:<value>`` line each; keys
+    that are missing or hold a falsy value are left out.
+
+    :return: the lines joined by newlines, or an empty string when nothing qualifies.
+    """
+    labelled_keys = (
+        ('name', '名字'),
+        ('gender', '性别'),
+        ('age', '年龄'),
+        ('region', '所在地'),
+        ('previous_location', '之前所在地'),
+        ('education', '学历'),
+        ('occupation', '职业'),
+        ('work_experience', '工作经历'),
+        ('family_members', '家庭成员'),
+        ('family_occupation', '家庭成员职业'),
+    )
+    return "\n".join(
+        f"- {label}:{profile[key]}"
+        for key, label in labelled_keys
+        if profile.get(key, None)
+    )

+ 102 - 0
pqai_agent_server/agent_server.py

@@ -0,0 +1,102 @@
+import logging
+import sys
+import time
+
+from pqai_agent import configs, logging_service
+from pqai_agent.agent_service import AgentService
+from pqai_agent.chat_service import ChatServiceType
+from pqai_agent.logging_service import logger
+from pqai_agent.message import MessageType, Message, MessageChannel
+from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend, MemoryQueueBackend
+from pqai_agent.push_service import PushTaskWorkerPool, PushScanThread
+from pqai_agent.user_manager import LocalUserManager, LocalUserRelationManager, MySQLUserManager, \
+    MySQLUserRelationManager
+
+# Entry point: build the queue backends and user managers from configuration,
+# start the AgentService, and optionally feed it messages typed on stdin.
+if __name__ == "__main__":
+    config = configs.get()
+    logging_service.setup_root_logger()
+    logger.warning("current env: {}".format(configs.get_env()))
+    # Only show WARNING and above from apscheduler
+    scheduler_logger = logging.getLogger('apscheduler')
+    scheduler_logger.setLevel(logging.WARNING)
+
+    use_aliyun_mq = config['debug_flags']['use_aliyun_mq']
+
+    # Initialise the backend of each queue
+    if use_aliyun_mq:
+        receive_queue = AliyunRocketMQQueueBackend(
+            config['mq']['endpoints'],
+            config['mq']['instance_id'],
+            config['mq']['receive_topic'],
+            has_consumer=True, has_producer=True,
+            group_id=config['mq']['receive_group'],
+            topic_type='FIFO', await_duration=10
+        )
+        send_queue = AliyunRocketMQQueueBackend(
+            config['mq']['endpoints'],
+            config['mq']['instance_id'],
+            config['mq']['send_topic'],
+            has_consumer=False, has_producer=True,
+            topic_type='FIFO'
+        )
+    else:
+        receive_queue = MemoryQueueBackend()
+        send_queue = MemoryQueueBackend()
+    # The human-intervention queue is always in-memory
+    human_queue = MemoryQueueBackend()
+
+    # Initialise the user management services
+    # FIXME(zhoutian): this database configuration is unnecessary when MySQL is not used
+    user_db_config = config['storage']['user']
+    staff_db_config = config['storage']['staff']
+    wecom_db_config = config['storage']['user_relation']
+    if config['debug_flags'].get('use_local_user_storage', False):
+        user_manager = LocalUserManager()
+        user_relation_manager = LocalUserRelationManager()
+    else:
+        user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
+        user_relation_manager = MySQLUserRelationManager(
+            user_db_config['mysql'], wecom_db_config['mysql'],
+            config['storage']['staff']['table'],
+            user_db_config['table'],
+            wecom_db_config['table']['staff'],
+            wecom_db_config['table']['relation'],
+            wecom_db_config['table']['user']
+        )
+
+    # Create the Agent service
+    service = AgentService(
+        receive_backend=receive_queue,
+        send_backend=send_queue,
+        human_backend=human_queue,
+        user_manager=user_manager,
+        user_relation_manager=user_relation_manager,
+        chat_service_type=ChatServiceType.COZE_CHAT
+    )
+
+    # Without console input the service runs in blocking mode and the process
+    # exits once start() returns; otherwise fall through to the stdin loop below.
+    if not config['debug_flags'].get('console_input', False):
+        service.start(blocking=True)
+        sys.exit(0)
+    else:
+        service.start()
+
+    # Console debugging loop: every non-empty line becomes one message on the receive queue
+    message_id = 0
+    while service.running:
+        print("Input next message: ")
+        text = sys.stdin.readline().strip()
+        if not text:
+            continue
+        message_id += 1
+        # Hard-coded sender/receiver ids used for console testing
+        sender = '7881301903997433'
+        receiver = '1688855931724582'
+        # Typing one of these MessageType names sends that control message with no content
+        if text in (MessageType.AGGREGATION_TRIGGER.name,
+                    MessageType.HUMAN_INTERVENTION_END.name):
+            message = Message.build(
+                MessageType.__members__.get(text),
+                MessageChannel.CORP_WECHAT,
+                sender, receiver, None, int(time.time() * 1000))
+        else:
+            message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+                                    sender,receiver, text, int(time.time() * 1000)
+                                    )
+        message.msgId = message_id
+        receive_queue.produce(message)
+        time.sleep(0.1)

+ 25 - 0
pqai_agent_server/api_server.py

@@ -158,6 +158,31 @@ def run_prompt():
         return wrap_response(500, msg="Error: {}".format(e))
 
 
+# Health-check endpoint: always answers with code 200 and message "OK".
+@app.route("/api/healthCheck", methods=["GET"])
+def health_check():
+    return wrap_response(200, msg="OK")
+
+
+# Stub endpoint: currently returns a fixed 200/"OK" response and no summary data.
+@app.route("/api/getStaffSessionSummary", methods=["GET"])
+def get_staff_session_summary():
+    return wrap_response(200, msg="OK")
+
+
+# Stub endpoint: currently returns a fixed 200/"OK" response and no session list.
+@app.route("/api/getStaffSessionList", methods=["GET"])
+def get_staff_session_list():
+    return wrap_response(200, msg="OK")
+
+
+# Stub endpoint: currently returns a fixed 200/"OK" response and no conversation data.
+# The rule must start with "/" — routing rejects rules without a leading slash.
+@app.route("/api/getConversationList", methods=["GET"])
+def get_conversation_list():
+    return wrap_response(200, msg="OK")
+
+
+# Stub endpoint: currently returns a fixed 200/"OK" response and sends nothing.
+# The rule must start with "/" — routing rejects rules without a leading slash.
+@app.route("/api/sendMessage", methods=["POST"])
+def send_message():
+    return wrap_response(200, msg="OK")
+
+
 @app.errorhandler(werkzeug.exceptions.BadRequest)
 def handle_bad_request(e):
     logger.error(e)

+ 5 - 1
requirements.txt

@@ -53,4 +53,8 @@ docstring_parser~=0.16
 pyapollos~=0.1.5
 Werkzeug~=3.1.3
 Flask~=3.1.0
-jsonschema~=4.23.0
+jsonschema~=4.23.0
+pqai_agent~=0.1.0
+numpy~=2.2.5
+pillow~=11.2.1
+json5~=0.12.0

+ 6 - 6
tests/unit_test.py

@@ -52,7 +52,7 @@ def test_env():
 
 def test_agent_state_change(test_env):
     service, _ = test_env
-    agent = service._get_agent_instance('staff_id_0', 'user_id_0')
+    agent = service.get_agent_instance('staff_id_0', 'user_id_0')
     assert agent.current_state == DialogueState.INITIALIZED
     assert agent.previous_state == DialogueState.INITIALIZED
 
@@ -110,7 +110,7 @@ def test_response_sanitization(test_env):
 def test_normal_conversation_flow(test_env):
     """测试正常对话流程"""
     service, queues = test_env
-    service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
+    service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
 
     # 准备测试消息
     test_msg = Message.build(
@@ -132,7 +132,7 @@ def test_normal_conversation_flow(test_env):
 def test_aggregated_conversation_flow(test_env):
     """测试聚合对话流程"""
     service, queues = test_env
-    service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 1
+    service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 1
 
     # 准备测试消息
     ts_begin = int(time.time() * 1000)
@@ -175,7 +175,7 @@ def test_aggregated_conversation_flow(test_env):
 def test_human_intervention_trigger(test_env):
     """测试触发人工干预"""
     service, queues = test_env
-    service._get_agent_instance('staff_id_0',"user_id_0").message_aggregation_sec = 0
+    service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
 
     # 准备需要人工干预的消息
     test_msg = Message.build(
@@ -201,11 +201,11 @@ def test_human_intervention_trigger(test_env):
 def test_initiative_conversation(test_env):
     """测试主动发起对话"""
     service, queues = test_env
-    service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
+    service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
     service._call_chat_api = Mock(return_value="主动发起模拟消息")
 
     # 设置Agent需要主动发起对话
-    agent = service._get_agent_instance('staff_id_0', "user_id_0")
+    agent = service.get_agent_instance('staff_id_0', "user_id_0")
     agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
     # 发起对话有时间限制
     agent.get_time_context = Mock(return_value=TimeContext.MORNING)