|
@@ -6,7 +6,7 @@ import re
|
|
|
import signal
|
|
|
import sys
|
|
|
import time
|
|
|
-from typing import Dict, List, Optional
|
|
|
+from typing import Dict, List, Optional, Union
|
|
|
import logging
|
|
|
from datetime import datetime, timedelta
|
|
|
import threading
|
|
@@ -18,6 +18,7 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
|
|
from pqai_agent import configs
|
|
|
+from pqai_agent.agents.message_reply_agent import MessageReplyAgent
|
|
|
from pqai_agent.configs import apollo_config
|
|
|
from pqai_agent.exceptions import NoRetryException
|
|
|
from pqai_agent.logging_service import logger
|
|
@@ -31,7 +32,7 @@ from pqai_agent.response_type_detector import ResponseTypeDetector
|
|
|
from pqai_agent.user_manager import UserManager, UserRelationManager
|
|
|
from pqai_agent.message_queue_backend import MessageQueueBackend, AliyunRocketMQQueueBackend
|
|
|
from pqai_agent.user_profile_extractor import UserProfileExtractor
|
|
|
-from pqai_agent.message import MessageType, Message, MessageChannel
|
|
|
+from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
|
|
|
from pqai_agent.utils.db_utils import create_sql_engine
|
|
|
|
|
|
|
|
@@ -136,7 +137,7 @@ class AgentService:
|
|
|
time.sleep(1)
|
|
|
logger.info("Scheduler event processing thread exit")
|
|
|
|
|
|
- def process_scheduler_event(self, msg: Message):
|
|
|
+ def process_scheduler_event(self, msg: MqMessage):
|
|
|
if msg.type == MessageType.AGGREGATION_TRIGGER:
|
|
|
# 延迟触发的消息,需放入接收队列以驱动Agent运转
|
|
|
self.receive_queue.produce(msg)
|
|
@@ -148,7 +149,7 @@ class AgentService:
|
|
|
agent_key = 'agent_{}_{}'.format(staff_id, user_id)
|
|
|
if agent_key not in self.agent_registry:
|
|
|
self.agent_registry[agent_key] = DialogueManager(
|
|
|
- staff_id, user_id, self.user_manager, self.agent_state_cache)
|
|
|
+ staff_id, user_id, self.user_manager, self.agent_state_cache, self.AgentDBSession)
|
|
|
agent = self.agent_registry[agent_key]
|
|
|
agent.refresh_profile()
|
|
|
return agent
|
|
@@ -190,7 +191,7 @@ class AgentService:
|
|
|
logger.error("Error processing message: {}, {}".format(e, error_stack))
|
|
|
time.sleep(0.1)
|
|
|
receive_queue.shutdown()
|
|
|
- logger.info("Message processing thread exit")
|
|
|
+ logger.info("MqMessage processing thread exit")
|
|
|
|
|
|
def start(self, blocking=False):
|
|
|
self.running = True
|
|
@@ -255,7 +256,7 @@ class AgentService:
|
|
|
def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
|
|
|
logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
|
|
|
message_ts = int((time.time() + delay_sec) * 1000)
|
|
|
- msg = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
|
|
|
+ msg = MqMessage.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
|
|
|
# 系统消息使用特定的msgId,无实际意义
|
|
|
msg.msgId = -MessageType.AGGREGATION_TRIGGER.value
|
|
|
if self.scheduler_mode == 'mq':
|
|
@@ -265,7 +266,7 @@ class AgentService:
|
|
|
'date',
|
|
|
run_date=datetime.now() + timedelta(seconds=delay_sec))
|
|
|
|
|
|
- def process_single_message(self, message: Message):
|
|
|
+ def process_single_message(self, message: MqMessage):
|
|
|
user_id = message.sender
|
|
|
staff_id = message.receiver
|
|
|
|
|
@@ -293,16 +294,8 @@ class AgentService:
|
|
|
elif need_response:
|
|
|
# 先更新用户画像再处理回复
|
|
|
self._update_user_profile(user_id, user_profile, agent.dialogue_history[-10:])
|
|
|
- resp = self._get_chat_response(user_id, agent, message_text)
|
|
|
- if resp:
|
|
|
- recent_dialogue = agent.dialogue_history[-10:]
|
|
|
- agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
|
|
|
- if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
|
|
|
- message_type = MessageType.TEXT
|
|
|
- else:
|
|
|
- message_type = self.response_type_detector.detect_type(
|
|
|
- recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
|
|
|
- self.send_response(staff_id, user_id, resp, message_type)
|
|
|
+ resp = self.get_chat_response(agent, message_text)
|
|
|
+ self.send_responses(agent, resp)
|
|
|
else:
|
|
|
logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
|
|
|
# 当前消息处理成功,commit并持久化agent状态
|
|
@@ -311,27 +304,53 @@ class AgentService:
|
|
|
agent.rollback_state()
|
|
|
raise e
|
|
|
|
|
|
- def send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
|
|
|
- logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
|
|
|
- current_ts = int(time.time() * 1000)
|
|
|
+ def send_responses(self, agent: DialogueManager, contents: List[Dict]):
|
|
|
+ staff_id = agent.staff_id
|
|
|
+ user_id = agent.user_id
|
|
|
+ recent_dialogue = agent.dialogue_history[-10:]
|
|
|
+ agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
|
|
|
+ for item in contents:
|
|
|
+ if item["type"] == MessageType.TEXT:
|
|
|
+ if staff_id in agent_voice_whitelist:
|
|
|
+ message_type = self.response_type_detector.detect_type(
|
|
|
+ recent_dialogue, item["content"], enable_random=True)
|
|
|
+ item["type"] = message_type
|
|
|
+ if contents:
|
|
|
+ current_ts = int(time.time())
|
|
|
+ for response in contents:
|
|
|
+ self.send_multimodal_response(staff_id, user_id, response, skip_check=True)
|
|
|
+ agent.update_last_active_interaction_time(current_ts)
|
|
|
+ else:
|
|
|
+ logger.debug(f"staff[{staff_id}], user[{user_id}]: no messages to send")
|
|
|
+
|
|
|
+ def can_send_to_user(self, staff_id, user_id) -> bool:
|
|
|
user_tags = self.user_relation_manager.get_user_tags(user_id)
|
|
|
white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags", []))
|
|
|
hit_white_list_tags = len(set(user_tags).intersection(white_list_tags)) > 0
|
|
|
- # FIXME(zhoutian)
|
|
|
- # 测试期间临时逻辑,只发送特定的账号或特定用户
|
|
|
staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs", []))
|
|
|
- if not (staff_id in staff_white_lists or hit_white_list_tags or skip_check):
|
|
|
+ if not (staff_id in staff_white_lists or hit_white_list_tags):
|
|
|
logger.warning(f"staff[{staff_id}] user[{user_id}]: skip reply")
|
|
|
+ return False
|
|
|
+ return True
|
|
|
+
|
|
|
+ def send_multimodal_response(self, staff_id, user_id, response: Dict, skip_check=False):
|
|
|
+ message_type = response["type"]
|
|
|
+ logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
|
|
|
+ if message_type not in (MessageType.TEXT, MessageType.IMAGE_QW, MessageType.VOICE):
|
|
|
+ logger.error(f"staff[{staff_id}] user[{user_id}]: unsupported message type {message_type}")
|
|
|
+ return
|
|
|
+ if not skip_check and not self.can_send_to_user(staff_id, user_id):
|
|
|
return
|
|
|
+ current_ts = int(time.time() * 1000)
|
|
|
self.send_rate_limiter.wait_for_sending(staff_id, response)
|
|
|
self.send_queue.produce(
|
|
|
- Message.build(message_type, MessageChannel.CORP_WECHAT,
|
|
|
- staff_id, user_id, response, current_ts)
|
|
|
+ MqMessage.build(message_type, MessageChannel.CORP_WECHAT,
|
|
|
+ staff_id, user_id, response["content"], current_ts)
|
|
|
)
|
|
|
|
|
|
- def _route_to_human_intervention(self, user_id: str, origin_message: Message):
|
|
|
+ def _route_to_human_intervention(self, user_id: str, origin_message: MqMessage):
|
|
|
"""路由到人工干预"""
|
|
|
- self.human_queue.produce(Message.build(
|
|
|
+ self.human_queue.produce(MqMessage.build(
|
|
|
MessageType.TEXT,
|
|
|
origin_message.channel,
|
|
|
origin_message.sender,
|
|
@@ -392,68 +411,15 @@ class AgentService:
|
|
|
# 问题在于,如果每次创建出新的PushTaskWorkerPool,在上次任务有未处理完的消息即退出时,会有未处理的消息堆积
|
|
|
push_task_worker_pool.wait_to_finish()
|
|
|
|
|
|
- def _check_initiative_conversations_v1(self):
|
|
|
- logger.info("start to check initiative conversations")
|
|
|
- if not DialogueManager.is_time_suitable_for_active_conversation():
|
|
|
- logger.info("time is not suitable for active conversation")
|
|
|
- return
|
|
|
- white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags', []))
|
|
|
- first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
|
|
|
- # 合并白名单,减少配置成本
|
|
|
- white_list_tags.update(first_initiate_tags)
|
|
|
- voice_tags = set(apollo_config.get_json_value('agent_initiate_by_voice_tags', []))
|
|
|
-
|
|
|
-
|
|
|
- """定时检查主动发起对话"""
|
|
|
- for staff_user in self.user_relation_manager.list_staff_users():
|
|
|
- staff_id = staff_user['staff_id']
|
|
|
- user_id = staff_user['user_id']
|
|
|
- agent = self.get_agent_instance(staff_id, user_id)
|
|
|
- should_initiate = agent.should_initiate_conversation()
|
|
|
- user_tags = self.user_relation_manager.get_user_tags(user_id)
|
|
|
-
|
|
|
- if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
|
|
|
- should_initiate = False
|
|
|
-
|
|
|
- if should_initiate:
|
|
|
- logger.warning(f"user[{user_id}], tags{user_tags}: initiate conversation")
|
|
|
- # FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突 需要并入事件驱动框架
|
|
|
- agent.do_state_change(DialogueState.GREETING)
|
|
|
- try:
|
|
|
- if agent.previous_state == DialogueState.INITIALIZED or first_initiate_tags.intersection(user_tags):
|
|
|
- # 完全无交互历史的用户才使用此策略,但新用户接入即会产生“我已添加了你”的消息将Agent初始化
|
|
|
- # 因此存量用户无法使用该状态做实验
|
|
|
- # TODO:增加基于对话历史的判断、策略去重;如果对话间隔过长需要使用长期记忆检索;在无长期记忆时,可采用用户添加时间来判断
|
|
|
- resp = self._generate_active_greeting_message(agent, user_tags)
|
|
|
- else:
|
|
|
- resp = self._get_chat_response(user_id, agent, None)
|
|
|
- if resp:
|
|
|
- if set(user_tags).intersection(voice_tags):
|
|
|
- message_type = MessageType.VOICE
|
|
|
- else:
|
|
|
- message_type = MessageType.TEXT
|
|
|
- self.send_response(staff_id, user_id, resp, message_type, skip_check=True)
|
|
|
- agent.persist_state()
|
|
|
- except Exception as e:
|
|
|
- # FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突
|
|
|
- agent.rollback_state()
|
|
|
- logger.error("Error in active greeting: {}".format(e))
|
|
|
- else:
|
|
|
- logger.debug(f"user[{user_id}], do not initiate conversation")
|
|
|
-
|
|
|
- def _generate_active_greeting_message(self, agent: DialogueManager, user_tags: List[str]=None):
|
|
|
- chat_config = agent.build_active_greeting_config(user_tags)
|
|
|
- chat_response = self._call_chat_api(chat_config, ChatServiceType.OPENAI_COMPATIBLE)
|
|
|
- chat_response = self.sanitize_response(chat_response)
|
|
|
- if response := agent.generate_response(chat_response):
|
|
|
- return response
|
|
|
+ def get_chat_response(self, agent: DialogueManager, user_message: Optional[str]) -> List[Dict]:
|
|
|
+ chat_agent_ver = self.config.get('system', {}).get('chat_agent_version', 1)
|
|
|
+ if chat_agent_ver == 2:
|
|
|
+ return self._get_chat_response_v2(agent)
|
|
|
else:
|
|
|
- logger.warning(f"staff[{agent.staff_id}] user[{agent.user_id}]: no response generated")
|
|
|
- return None
|
|
|
+ text_resp = self._get_chat_response_v1(agent, user_message)
|
|
|
+ return [{"type": MessageType.TEXT, "content": text_resp}] if text_resp else []
|
|
|
|
|
|
- def _get_chat_response(self, user_id: str, agent: DialogueManager,
|
|
|
- user_message: Optional[str]):
|
|
|
- """处理LLM响应"""
|
|
|
+ def _get_chat_response_v1(self, agent: DialogueManager, user_message: Optional[str]) -> Optional[str]:
|
|
|
chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
|
|
|
config_for_logging = chat_config.copy()
|
|
|
config_for_logging['messages'] = config_for_logging['messages'][-20:]
|
|
@@ -464,9 +430,27 @@ class AgentService:
|
|
|
if response := agent.generate_response(chat_response):
|
|
|
return response
|
|
|
else:
|
|
|
- logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
|
|
|
+ logger.warning(f"staff[{agent.staff_id}] user[{agent.user_id}]: no response generated")
|
|
|
return None
|
|
|
|
|
|
+ def _get_chat_response_v2(self, main_agent: DialogueManager) -> List[Dict]:
|
|
|
+ chat_agent = MessageReplyAgent()
|
|
|
+ chat_responses = chat_agent.generate_message(
|
|
|
+ context=main_agent.get_prompt_context(None),
|
|
|
+ dialogue_history=main_agent.dialogue_history[-100:]
|
|
|
+ )
|
|
|
+ if not chat_responses:
|
|
|
+ logger.warning(f"staff[{main_agent.staff_id}] user[{main_agent.user_id}]: no response generated")
|
|
|
+ return []
|
|
|
+ final_responses = []
|
|
|
+ for chat_response in chat_responses:
|
|
|
+ if response := main_agent.generate_multimodal_response(chat_response):
|
|
|
+ final_responses.append(response)
|
|
|
+ else:
|
|
|
+ # 存在非法/结束消息,清空待发消息
|
|
|
+ final_responses.clear()
|
|
|
+ return final_responses
|
|
|
+
|
|
|
def _call_chat_api(self, chat_config: Dict, chat_service_type: ChatServiceType) -> str:
|
|
|
if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
|
|
|
return 'LLM模拟回复 {}'.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
|