#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8

import sys
import time
from typing import Dict, List, Tuple, Any, Optional
import logging
from datetime import datetime, timedelta
import traceback

import apscheduler.triggers.cron
from apscheduler.schedulers.background import BackgroundScheduler

import chat_service
import configs
import logging_service
from logging_service import logger
from chat_service import CozeChat, ChatServiceType
from dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager
from openai import OpenAI
from message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
from user_profile_extractor import UserProfileExtractor
import threading
from message import MessageType, Message, MessageChannel


class AgentService:
    """Route incoming messages through per-(staff, user) dialogue agents.

    Messages are consumed from ``receive_backend``. Generated replies are
    produced to ``send_backend``. Conversations that need a human are
    forwarded to ``human_backend``.

    A ``BackgroundScheduler`` is started on construction. It is used for
    delayed aggregation triggers and for periodic proactive-conversation
    checks.
    """

    # FIXME(zhoutian): temporary testing logic -- only replies from these
    # staff accounts are actually put on the send queue.
    _TEST_ALLOWED_STAFF_IDS = frozenset({'1688854492669990'})

    def __init__(
            self,
            receive_backend: MessageQueueBackend,
            send_backend: MessageQueueBackend,
            human_backend: MessageQueueBackend,
            user_manager: UserManager,
            user_relation_manager: UserRelationManager,
            chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
    ):
        self.receive_queue = receive_backend
        self.send_queue = send_backend
        self.human_queue = human_backend

        # Core service modules
        self.agent_state_cache = DialogueStateCache()
        self.user_manager = user_manager
        self.user_relation_manager = user_relation_manager
        self.user_profile_extractor = UserProfileExtractor()
        self.agent_registry: Dict[str, DialogueManager] = {}
        # agent_registry is read and written both from the message-processing
        # thread and from scheduler job threads, so lookups/inserts are locked.
        self._agent_registry_lock = threading.Lock()

        # DeepSeek on Volces (OpenAI-compatible endpoint)
        self.llm_client = OpenAI(
            api_key=chat_service.VOLCENGINE_API_TOKEN,
            base_url=chat_service.VOLCENGINE_BASE_URL
        )
        self.model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3

        coze_config = configs.get()['chat_api']['coze']
        coze_oauth_app = CozeChat.get_oauth_app(
            coze_config['oauth_client_id'],
            coze_config['private_key_path'],
            str(coze_config['public_key_id']),
            account_id=coze_config.get('account_id', None)
        )
        self.coze_client = CozeChat(
            base_url=chat_service.COZE_CN_BASE_URL,
            auth_app=coze_oauth_app
        )
        self.chat_service_type = chat_service_type

        # Scheduler for timed jobs (aggregation triggers, initiative checks)
        self.scheduler = BackgroundScheduler()
        self.scheduler.start()

    def setup_initiative_conversations(self, schedule_params: Optional[Dict] = None):
        """Register a cron job that periodically checks for proactive conversations.

        :param schedule_params: keyword arguments for ``CronTrigger``; defaults
            to ``{'hour': '8,16,20'}`` when empty or omitted.
        """
        if not schedule_params:
            schedule_params = {'hour': '8,16,20'}
        self.scheduler.add_job(
            self._check_initiative_conversations,
            apscheduler.triggers.cron.CronTrigger(**schedule_params)
        )

    def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
        """Return the cached agent for (staff_id, user_id), creating it on first use."""
        agent_key = 'agent_{}_{}'.format(staff_id, user_id)
        with self._agent_registry_lock:
            agent = self.agent_registry.get(agent_key)
            if agent is None:
                agent = DialogueManager(
                    staff_id, user_id, self.user_manager, self.agent_state_cache)
                self.agent_registry[agent_key] = agent
        return agent

    def process_messages(self):
        """Consume the receive queue forever, processing and acking each message.

        A message is acked only after it was processed successfully. Failures
        while consuming or processing are logged and the loop keeps running.
        """
        while True:
            try:
                message = self.receive_queue.consume()
            except Exception:
                # An uncaught error here would silently kill this thread.
                logger.exception("Error consuming message")
                message = None
            if message:
                try:
                    self.process_single_message(message)
                    self.receive_queue.ack(message)
                except Exception as e:
                    logger.exception("Error processing message: {}".format(e))
            time.sleep(1)

    def _update_user_profile(self, user_id: str, user_profile, message: str):
        """Extract profile info from ``message`` and persist the merged profile.

        :return: the merged profile, or None when nothing was extracted.
        """
        profile_to_update = self.user_profile_extractor.extract_profile_info(user_profile, message)
        if not profile_to_update:
            logger.debug("user_id: {}, no profile info extracted".format(user_id))
            return None
        logger.warning("update user profile: {}".format(profile_to_update))
        merged_profile = self.user_profile_extractor.merge_profile_info(user_profile, profile_to_update)
        self.user_manager.save_user_profile(user_id, merged_profile)
        return merged_profile

    def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
        """Put an AGGREGATION_TRIGGER message on the receive queue after ``delay_sec`` seconds."""
        logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
        message_ts = int((time.time() + delay_sec) * 1000)
        message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM,
                                user_id, staff_id, None, message_ts)
        # System messages use a fixed, meaningless msgId
        message.msgId = -MessageType.AGGREGATION_TRIGGER.value
        self.scheduler.add_job(lambda: self.receive_queue.produce(message), 'date',
                               run_date=datetime.now() + timedelta(seconds=delay_sec))

    def process_single_message(self, message: Message):
        """Update the dialogue state for one message and route it.

        The message goes to human intervention, to delayed aggregation, or to
        LLM reply generation, depending on the agent's resulting state.
        """
        user_id = message.sender
        staff_id = message.receiver

        # Fetch the user profile and the agent instance
        user_profile = self.user_manager.get_user_profile(user_id)
        agent = self._get_agent_instance(staff_id, user_id)

        # Update the dialogue state
        logger.debug("process message: {}".format(message))
        need_response, message_text = agent.update_state(message)
        logger.debug("user: {}, next state: {}".format(user_id, agent.current_state))

        # Route according to the new state
        if agent.is_in_human_intervention():
            self._route_to_human_intervention(user_id, message)
        elif agent.current_state == DialogueState.MESSAGE_AGGREGATING:
            if message.type != MessageType.AGGREGATION_TRIGGER:
                # Schedule a trigger, but never let a trigger spawn another one
                logger.debug("user: {}, waiting next message for aggregation".format(user_id))
                self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
            return
        elif need_response:
            # Update the user profile before generating the reply
            self._update_user_profile(user_id, user_profile, message_text)
            self._get_chat_response(user_id, agent, message_text)
        else:
            logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")

    def _route_to_human_intervention(self, user_id: str, origin_message: Message):
        """Post a text notice to the human queue saying this user needs a human."""
        self.human_queue.produce(Message.build(
            MessageType.TEXT,
            origin_message.channel,
            origin_message.sender,
            origin_message.receiver,
            "用户对话需人工介入,用户名:{}".format(user_id),
            int(time.time() * 1000)
        ))

    def _check_initiative_conversations(self):
        """Scheduled job: start a conversation with every user whose agent asks for one.

        A failure for one staff/user pair is logged and does not stop the
        remaining pairs from being checked.
        """
        for staff_user in self.user_relation_manager.list_staff_users():
            staff_id = staff_user['staff_id']
            user_id = staff_user['user_id']
            try:
                agent = self._get_agent_instance(staff_id, user_id)
                if agent.should_initiate_conversation():
                    logger.warning("user: {}, initiate conversation".format(user_id))
                    self._get_chat_response(user_id, agent, None)
                else:
                    logger.debug("user: {}, do not initiate conversation".format(user_id))
            except Exception:
                logger.exception(
                    f"staff[{staff_id}] user[{user_id}]: failed to check initiative conversation")

    def _get_chat_response(self, user_id: str, agent: DialogueManager, user_message: Optional[str]):
        """Ask the LLM for a reply and, if the agent produces one, enqueue it for sending.

        :param user_message: the user's text, or None for an agent-initiated turn.
        """
        chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
        logger.debug(chat_config)
        raw_response = self._call_chat_api(chat_config)
        if raw_response is None:
            # e.g. OpenAI's message.content may be None; .strip() would crash
            logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: chat api returned no content")
            return
        # FIXME(zhoutian): temporarily strip leading/trailing whitespace
        chat_response = raw_response.strip()

        response = agent.generate_response(chat_response)
        if not response:
            logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
            return

        logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: response: {response}")
        current_ts = int(time.time() * 1000)
        # FIXME(zhoutian): temporary testing logic, only send for specific staff
        if agent.staff_id not in self._TEST_ALLOWED_STAFF_IDS:
            logger.warning(f"skip message from sender [{agent.staff_id}]")
            return
        self.send_queue.produce(
            Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
                          agent.staff_id, user_id, response, current_ts)
        )

    def _call_chat_api(self, chat_config: Dict) -> Optional[str]:
        """Send ``chat_config`` to the configured chat backend and return its text.

        May return None if the backend yields no content.

        :raises ValueError: if ``self.chat_service_type`` is not supported.
        """
        if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
            return 'LLM模拟回复 {}'.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
        if self.chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
            chat_completion = self.llm_client.chat.completions.create(
                messages=chat_config['messages'],
                model=self.model_name,
            )
            return chat_completion.choices[0].message.content
        if self.chat_service_type == ChatServiceType.COZE_CHAT:
            bot_user_id = 'dev_user'
            return self.coze_client.create(
                chat_config['bot_id'], bot_user_id, chat_config['messages'],
                chat_config['custom_variables']
            )
        raise ValueError('Unsupported chat service type: {}'.format(self.chat_service_type))


def _create_queue_backends(config: Dict):
    """Build (receive_queue, send_queue, human_queue) according to ``debug_flags``."""
    if config['debug_flags'].get('use_aliyun_mq', False):
        mq_config = config['mq']
        receive_queue = AliyunRocketMQQueueBackend(
            mq_config['endpoints'],
            mq_config['instance_id'],
            mq_config['receive_topic'],
            has_consumer=True, has_producer=True,
            group_id=mq_config['receive_group']
        )
        send_queue = AliyunRocketMQQueueBackend(
            mq_config['endpoints'],
            mq_config['instance_id'],
            mq_config['send_topic'],
            has_consumer=False, has_producer=True
        )
    else:
        receive_queue = MemoryQueueBackend()
        send_queue = MemoryQueueBackend()
    # The human-intervention queue is in-memory in both modes.
    human_queue = MemoryQueueBackend()
    return receive_queue, send_queue, human_queue


def _create_user_managers(config: Dict):
    """Build (user_manager, user_relation_manager) from ``config['storage']``."""
    # FIXME(zhoutian): if MySQL is not used, this DB config should be optional
    user_db_config = config['storage']['user']
    staff_db_config = config['storage']['staff']
    if config['debug_flags'].get('use_local_user_storage', False):
        user_manager = LocalUserManager()
    else:
        user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'],
                                        staff_db_config['table'])
    wecom_db_config = config['storage']['user_relation']
    user_relation_manager = MySQLUserRelationManager(
        user_db_config['mysql'], wecom_db_config['mysql'],
        staff_db_config['table'],
        user_db_config['table'],
        wecom_db_config['table']['staff'],
        wecom_db_config['table']['relation'],
        wecom_db_config['table']['user']
    )
    return user_manager, user_relation_manager


def _run_console_input(receive_queue: MessageQueueBackend):
    """Debug helper: read lines from stdin and inject them into the receive queue.

    Typing the name of ``MessageType.AGGREGATION_TRIGGER`` injects a trigger
    message; any other non-blank line is sent as a text message. Returns at EOF.
    """
    sender = '7881302581935903'
    receiver = '1688854492669990'
    message_id = 0
    while True:
        print("Input next message: ")
        line = sys.stdin.readline()
        if not line:
            # readline() returns '' only at EOF; without this check a closed
            # stdin would make the loop spin forever.
            break
        text = line.strip()
        if not text:
            continue
        message_id += 1
        now_ms = int(time.time() * 1000)
        if text == MessageType.AGGREGATION_TRIGGER.name:
            message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
                                    sender, receiver, None, now_ms)
        else:
            message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
                                    sender, receiver, text, now_ms)
        message.msgId = message_id
        receive_queue.produce(message)
        time.sleep(0.1)


def main():
    """Script entry point: wire up queues, storage and AgentService, then run."""
    config = configs.get()
    logging_service.setup_root_logger()
    logger.warning("current env: {}".format(configs.get_env()))
    scheduler_logger = logging.getLogger('apscheduler')
    scheduler_logger.setLevel(logging.WARNING)

    # Initialize the queue backends
    receive_queue, send_queue, human_queue = _create_queue_backends(config)

    # Initialize user management services
    user_manager, user_relation_manager = _create_user_managers(config)

    # Create the agent service
    service = AgentService(
        receive_backend=receive_queue,
        send_backend=send_queue,
        human_backend=human_queue,
        user_manager=user_manager,
        user_relation_manager=user_relation_manager,
        chat_service_type=ChatServiceType.COZE_CHAT
    )

    # Only the WeCom scenario needs proactive conversations
    if not config['debug_flags'].get('disable_active_conversation', False):
        service.setup_initiative_conversations({'second': '5,35'})

    process_thread = threading.Thread(target=service.process_messages)
    process_thread.start()

    if not config['debug_flags'].get('console_input', False):
        process_thread.join()
        sys.exit(0)

    _run_console_input(receive_queue)
    process_thread.join()


if __name__ == "__main__":
    main()