|
@@ -13,24 +13,25 @@ import threading
|
|
|
import traceback
|
|
|
|
|
|
import apscheduler.triggers.cron
|
|
|
+import rocketmq
|
|
|
from apscheduler.schedulers.background import BackgroundScheduler
|
|
|
|
|
|
from pqai_agent import configs
|
|
|
-from pqai_agent import logging_service
|
|
|
from pqai_agent.configs import apollo_config
|
|
|
+from pqai_agent.exceptions import NoRetryException
|
|
|
from pqai_agent.logging_service import logger
|
|
|
from pqai_agent import chat_service
|
|
|
from pqai_agent.chat_service import CozeChat, ChatServiceType
|
|
|
from pqai_agent.dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
|
|
|
+from pqai_agent.history_dialogue_service import HistoryDialogueDatabase
|
|
|
+from pqai_agent.push_service import PushScanThread, PushTaskWorkerPool
|
|
|
from pqai_agent.rate_limiter import MessageSenderRateLimiter
|
|
|
from pqai_agent.response_type_detector import ResponseTypeDetector
|
|
|
-from pqai_agent.user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager, \
|
|
|
- LocalUserRelationManager
|
|
|
+from pqai_agent.user_manager import UserManager, UserRelationManager
|
|
|
from pqai_agent.message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
|
|
|
from pqai_agent.user_profile_extractor import UserProfileExtractor
|
|
|
from pqai_agent.message import MessageType, Message, MessageChannel
|
|
|
|
|
|
-
|
|
|
class AgentService:
|
|
|
def __init__(
|
|
|
self,
|
|
@@ -41,6 +42,8 @@ class AgentService:
|
|
|
user_relation_manager: UserRelationManager,
|
|
|
chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
|
|
|
):
|
|
|
+ self.config = configs.get()
|
|
|
+
|
|
|
self.receive_queue = receive_backend
|
|
|
self.send_queue = send_backend
|
|
|
self.human_queue = human_backend
|
|
@@ -52,8 +55,8 @@ class AgentService:
|
|
|
self.user_profile_extractor = UserProfileExtractor()
|
|
|
self.response_type_detector = ResponseTypeDetector()
|
|
|
self.agent_registry: Dict[str, DialogueManager] = {}
|
|
|
+ self.history_dialogue_db = HistoryDialogueDatabase(self.config['storage']['user']['mysql'])
|
|
|
|
|
|
- self.config = configs.get()
|
|
|
chat_config = self.config['chat_api']['openai_compatible']
|
|
|
self.text_model_name = chat_config['text_model']
|
|
|
self.multimodal_model_name = chat_config['multimodal_model']
|
|
@@ -80,6 +83,13 @@ class AgentService:
|
|
|
self.process_thread = None
|
|
|
self._sigint_cnt = 0
|
|
|
|
|
|
+ # Push相关
|
|
|
+ self.push_task_producer = None
|
|
|
+ self.push_task_consumer = None
|
|
|
+ self._init_push_task_queue()
|
|
|
+ self.next_push_disabled = True
|
|
|
+ self._resume_unfinished_push_task()
|
|
|
+
|
|
|
self.send_rate_limiter = MessageSenderRateLimiter()
|
|
|
|
|
|
def setup_initiative_conversations(self, schedule_params: Optional[Dict] = None):
|
|
@@ -102,7 +112,8 @@ class AgentService:
|
|
|
topic,
|
|
|
has_consumer=True, has_producer=True,
|
|
|
group_id=mq_conf['scheduler_group'],
|
|
|
- topic_type='DELAY'
|
|
|
+ topic_type='DELAY',
|
|
|
+ await_duration=5
|
|
|
)
|
|
|
self.msg_scheduler_thread = threading.Thread(target=self.process_scheduler_events)
|
|
|
self.msg_scheduler_thread.start()
|
|
@@ -127,13 +138,15 @@ class AgentService:
|
|
|
else:
|
|
|
logger.warning(f"Unknown message type: {msg.type}")
|
|
|
|
|
|
- def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
|
|
|
+ def get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
|
|
|
"""获取Agent实例"""
|
|
|
agent_key = 'agent_{}_{}'.format(staff_id, user_id)
|
|
|
if agent_key not in self.agent_registry:
|
|
|
self.agent_registry[agent_key] = DialogueManager(
|
|
|
staff_id, user_id, self.user_manager, self.agent_state_cache)
|
|
|
- return self.agent_registry[agent_key]
|
|
|
+ agent = self.agent_registry[agent_key]
|
|
|
+ agent.refresh_profile()
|
|
|
+ return agent
|
|
|
|
|
|
def process_messages(self):
|
|
|
"""持续处理接收队列消息"""
|
|
@@ -143,15 +156,18 @@ class AgentService:
|
|
|
try:
|
|
|
self.process_single_message(message)
|
|
|
self.receive_queue.ack(message)
|
|
|
+ except NoRetryException as e:
|
|
|
+ logger.error("Error processing message and skip retry: {}".format(e))
|
|
|
+ self.receive_queue.ack(message)
|
|
|
except Exception as e:
|
|
|
- logger.error("Error processing message: {}".format(e))
|
|
|
- traceback.print_exc()
|
|
|
- time.sleep(1)
|
|
|
+ error_stack = traceback.format_exc()
|
|
|
+ logger.error("Error processing message: {}, {}".format(e, error_stack))
|
|
|
+ time.sleep(0.5)
|
|
|
logger.info("Message processing thread exit")
|
|
|
|
|
|
def start(self, blocking=False):
|
|
|
self.running = True
|
|
|
- self.process_thread = threading.Thread(target=service.process_messages)
|
|
|
+ self.process_thread = threading.Thread(target=self.process_messages)
|
|
|
self.process_thread.start()
|
|
|
self.setup_scheduler()
|
|
|
# 只有企微场景需要主动发起
|
|
@@ -217,10 +233,10 @@ class AgentService:
|
|
|
|
|
|
# 获取用户信息和Agent实例
|
|
|
user_profile = self.user_manager.get_user_profile(user_id)
|
|
|
- agent = self._get_agent_instance(staff_id, user_id)
|
|
|
+ agent = self.get_agent_instance(staff_id, user_id)
|
|
|
if not agent.is_valid():
|
|
|
logger.error(f"staff[{staff_id}] user[{user_id}]: agent is invalid")
|
|
|
- return
|
|
|
+ raise Exception('agent is invalid')
|
|
|
|
|
|
# 更新对话状态
|
|
|
logger.debug("process message: {}".format(message))
|
|
@@ -242,13 +258,13 @@ class AgentService:
|
|
|
resp = self._get_chat_response(user_id, agent, message_text)
|
|
|
if resp:
|
|
|
recent_dialogue = agent.dialogue_history[-10:]
|
|
|
- agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist"))
|
|
|
+ agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
|
|
|
if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
|
|
|
message_type = MessageType.TEXT
|
|
|
else:
|
|
|
message_type = self.response_type_detector.detect_type(
|
|
|
recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
|
|
|
- self._send_response(staff_id, user_id, resp, message_type)
|
|
|
+ self.send_response(staff_id, user_id, resp, message_type)
|
|
|
else:
|
|
|
logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
|
|
|
# 当前消息处理成功,commit并持久化agent状态
|
|
@@ -257,15 +273,15 @@ class AgentService:
|
|
|
agent.rollback_state()
|
|
|
raise e
|
|
|
|
|
|
- def _send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
|
|
|
+ def send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
|
|
|
logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
|
|
|
current_ts = int(time.time() * 1000)
|
|
|
user_tags = self.user_relation_manager.get_user_tags(user_id)
|
|
|
- white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags"))
|
|
|
+ white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags", []))
|
|
|
hit_white_list_tags = len(set(user_tags).intersection(white_list_tags)) > 0
|
|
|
# FIXME(zhoutian)
|
|
|
# 测试期间临时逻辑,只发送特定的账号或特定用户
|
|
|
- staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs"))
|
|
|
+ staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs", []))
|
|
|
if not (staff_id in staff_white_lists or hit_white_list_tags or skip_check):
|
|
|
logger.warning(f"staff[{staff_id}] user[{user_id}]: skip reply")
|
|
|
return
|
|
@@ -286,23 +302,75 @@ class AgentService:
|
|
|
int(time.time() * 1000)
|
|
|
))
|
|
|
|
|
|
+ def _init_push_task_queue(self):
|
|
|
+ credentials = rocketmq.Credentials()
|
|
|
+ mq_conf = configs.get()['mq']
|
|
|
+ rmq_client_conf = rocketmq.ClientConfiguration(mq_conf['endpoints'], credentials, mq_conf['instance_id'])
|
|
|
+ rmq_topic = mq_conf['push_tasks_topic']
|
|
|
+ rmq_group = mq_conf['push_tasks_group']
|
|
|
+ self.push_task_rmq_topic = rmq_topic
|
|
|
+ self.push_task_producer = rocketmq.Producer(rmq_client_conf, (rmq_topic,))
|
|
|
+ self.push_task_producer.startup()
|
|
|
+ self.push_task_consumer = rocketmq.SimpleConsumer(rmq_client_conf, rmq_group, await_duration=5)
|
|
|
+ self.push_task_consumer.startup()
|
|
|
+ self.push_task_consumer.subscribe(rmq_topic)
|
|
|
+
|
|
|
+
|
|
|
+ def _resume_unfinished_push_task(self):
|
|
|
+ def run_unfinished_push_task():
|
|
|
+ logger.info("start to resume unfinished push task")
|
|
|
+ push_task_worker_pool = PushTaskWorkerPool(
|
|
|
+ self, self.push_task_rmq_topic, self.push_task_consumer, self.push_task_producer)
|
|
|
+ push_task_worker_pool.start()
|
|
|
+ push_task_worker_pool.wait_to_finish()
|
|
|
+ self.next_push_disabled = False
|
|
|
+ logger.info("unfinished push tasks should be finished")
|
|
|
+ thread = threading.Thread(target=run_unfinished_push_task)
|
|
|
+ thread.start()
|
|
|
+
|
|
|
def _check_initiative_conversations(self):
|
|
|
logger.info("start to check initiative conversations")
|
|
|
+ if self.next_push_disabled:
|
|
|
+ logger.info("previous push tasks in processing, next push is disabled")
|
|
|
+ return
|
|
|
if not DialogueManager.is_time_suitable_for_active_conversation():
|
|
|
logger.info("time is not suitable for active conversation")
|
|
|
return
|
|
|
- white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags'))
|
|
|
+
|
|
|
+ push_scan_threads = []
|
|
|
+ for staff in self.user_relation_manager.list_staffs():
|
|
|
+ staff_id = staff['third_party_user_id']
|
|
|
+ scan_thread = threading.Thread(target=PushScanThread(
|
|
|
+ staff_id, self, self.push_task_rmq_topic, self.push_task_producer).run)
|
|
|
+ scan_thread.start()
|
|
|
+ push_scan_threads.append(scan_thread)
|
|
|
+
|
|
|
+ push_task_worker_pool = PushTaskWorkerPool(
|
|
|
+ self, self.push_task_rmq_topic, self.push_task_consumer, self.push_task_producer)
|
|
|
+ push_task_worker_pool.start()
|
|
|
+ for thread in push_scan_threads:
|
|
|
+ thread.join()
|
|
|
+ # 由于扫描和生成异步,两次扫描之间可能消息并未处理完,会有重复生成任务的情况,因此需等待上一次任务结束
|
|
|
+ # 问题在于,如果每次创建出新的PushTaskWorkerPool,在上次任务有未处理完的消息即退出时,会有未处理的消息堆积
|
|
|
+ push_task_worker_pool.wait_to_finish()
|
|
|
+
|
|
|
+ def _check_initiative_conversations_v1(self):
|
|
|
+ logger.info("start to check initiative conversations")
|
|
|
+ if not DialogueManager.is_time_suitable_for_active_conversation():
|
|
|
+ logger.info("time is not suitable for active conversation")
|
|
|
+ return
|
|
|
+ white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags', []))
|
|
|
first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
|
|
|
# 合并白名单,减少配置成本
|
|
|
white_list_tags.update(first_initiate_tags)
|
|
|
- voice_tags = set(apollo_config.get_json_value('agent_initiate_by_voice_tags'))
|
|
|
+ voice_tags = set(apollo_config.get_json_value('agent_initiate_by_voice_tags', []))
|
|
|
|
|
|
|
|
|
"""定时检查主动发起对话"""
|
|
|
for staff_user in self.user_relation_manager.list_staff_users():
|
|
|
staff_id = staff_user['staff_id']
|
|
|
user_id = staff_user['user_id']
|
|
|
- agent = self._get_agent_instance(staff_id, user_id)
|
|
|
+ agent = self.get_agent_instance(staff_id, user_id)
|
|
|
should_initiate = agent.should_initiate_conversation()
|
|
|
user_tags = self.user_relation_manager.get_user_tags(user_id)
|
|
|
|
|
@@ -326,7 +394,7 @@ class AgentService:
|
|
|
message_type = MessageType.VOICE
|
|
|
else:
|
|
|
message_type = MessageType.TEXT
|
|
|
- self._send_response(staff_id, user_id, resp, message_type, skip_check=True)
|
|
|
+ self.send_response(staff_id, user_id, resp, message_type, skip_check=True)
|
|
|
agent.persist_state()
|
|
|
except Exception as e:
|
|
|
# FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突
|
|
@@ -398,93 +466,4 @@ class AgentService:
|
|
|
pattern = r'\[?\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]?'
|
|
|
response = re.sub(pattern, '', response)
|
|
|
response = response.strip()
|
|
|
- return response
|
|
|
-
|
|
|
-if __name__ == "__main__":
|
|
|
- config = configs.get()
|
|
|
- logging_service.setup_root_logger()
|
|
|
- logger.warning("current env: {}".format(configs.get_env()))
|
|
|
- scheduler_logger = logging.getLogger('apscheduler')
|
|
|
- scheduler_logger.setLevel(logging.WARNING)
|
|
|
-
|
|
|
- use_aliyun_mq = config['debug_flags']['use_aliyun_mq']
|
|
|
-
|
|
|
- # 初始化不同队列的后端
|
|
|
- if use_aliyun_mq:
|
|
|
- receive_queue = AliyunRocketMQQueueBackend(
|
|
|
- config['mq']['endpoints'],
|
|
|
- config['mq']['instance_id'],
|
|
|
- config['mq']['receive_topic'],
|
|
|
- has_consumer=True, has_producer=True,
|
|
|
- group_id=config['mq']['receive_group'],
|
|
|
- topic_type='FIFO'
|
|
|
- )
|
|
|
- send_queue = AliyunRocketMQQueueBackend(
|
|
|
- config['mq']['endpoints'],
|
|
|
- config['mq']['instance_id'],
|
|
|
- config['mq']['send_topic'],
|
|
|
- has_consumer=False, has_producer=True,
|
|
|
- topic_type='FIFO'
|
|
|
- )
|
|
|
- else:
|
|
|
- receive_queue = MemoryQueueBackend()
|
|
|
- send_queue = MemoryQueueBackend()
|
|
|
- human_queue = MemoryQueueBackend()
|
|
|
-
|
|
|
- # 初始化用户管理服务
|
|
|
- # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
|
|
|
- user_db_config = config['storage']['user']
|
|
|
- staff_db_config = config['storage']['staff']
|
|
|
- wecom_db_config = config['storage']['user_relation']
|
|
|
- if config['debug_flags'].get('use_local_user_storage', False):
|
|
|
- user_manager = LocalUserManager()
|
|
|
- user_relation_manager = LocalUserRelationManager()
|
|
|
- else:
|
|
|
- user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
|
|
|
- user_relation_manager = MySQLUserRelationManager(
|
|
|
- user_db_config['mysql'], wecom_db_config['mysql'],
|
|
|
- config['storage']['staff']['table'],
|
|
|
- user_db_config['table'],
|
|
|
- wecom_db_config['table']['staff'],
|
|
|
- wecom_db_config['table']['relation'],
|
|
|
- wecom_db_config['table']['user']
|
|
|
- )
|
|
|
-
|
|
|
- # 创建Agent服务
|
|
|
- service = AgentService(
|
|
|
- receive_backend=receive_queue,
|
|
|
- send_backend=send_queue,
|
|
|
- human_backend=human_queue,
|
|
|
- user_manager=user_manager,
|
|
|
- user_relation_manager=user_relation_manager,
|
|
|
- chat_service_type=ChatServiceType.COZE_CHAT
|
|
|
- )
|
|
|
-
|
|
|
- if not config['debug_flags'].get('console_input', False):
|
|
|
- service.start(blocking=True)
|
|
|
- sys.exit(0)
|
|
|
- else:
|
|
|
- service.start()
|
|
|
-
|
|
|
- message_id = 0
|
|
|
- while service.running:
|
|
|
- print("Input next message: ")
|
|
|
- text = sys.stdin.readline().strip()
|
|
|
- if not text:
|
|
|
- continue
|
|
|
- message_id += 1
|
|
|
- sender = '7881301903997433'
|
|
|
- receiver = '1688855931724582'
|
|
|
- if text in (MessageType.AGGREGATION_TRIGGER.name,
|
|
|
- MessageType.HUMAN_INTERVENTION_END.name):
|
|
|
- message = Message.build(
|
|
|
- MessageType.__members__.get(text),
|
|
|
- MessageChannel.CORP_WECHAT,
|
|
|
- sender, receiver, None, int(time.time() * 1000))
|
|
|
- else:
|
|
|
- message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
|
|
|
- sender,receiver, text, int(time.time() * 1000)
|
|
|
- )
|
|
|
- message.msgId = message_id
|
|
|
- receive_queue.produce(message)
|
|
|
- time.sleep(0.1)
|
|
|
+ return response
|