|
@@ -16,7 +16,6 @@ import apscheduler.triggers.cron
|
|
from apscheduler.schedulers.background import BackgroundScheduler
|
|
from apscheduler.schedulers.background import BackgroundScheduler
|
|
|
|
|
|
from pqai_agent import configs
|
|
from pqai_agent import configs
|
|
-from pqai_agent import logging_service
|
|
|
|
from pqai_agent.configs import apollo_config
|
|
from pqai_agent.configs import apollo_config
|
|
from pqai_agent.logging_service import logger
|
|
from pqai_agent.logging_service import logger
|
|
from pqai_agent import chat_service
|
|
from pqai_agent import chat_service
|
|
@@ -24,8 +23,7 @@ from pqai_agent.chat_service import CozeChat, ChatServiceType
|
|
from pqai_agent.dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
|
|
from pqai_agent.dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
|
|
from pqai_agent.rate_limiter import MessageSenderRateLimiter
|
|
from pqai_agent.rate_limiter import MessageSenderRateLimiter
|
|
from pqai_agent.response_type_detector import ResponseTypeDetector
|
|
from pqai_agent.response_type_detector import ResponseTypeDetector
|
|
-from pqai_agent.user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager, \
|
|
|
|
- LocalUserRelationManager
|
|
|
|
|
|
+from pqai_agent.user_manager import UserManager, UserRelationManager
|
|
from pqai_agent.message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
|
|
from pqai_agent.message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
|
|
from pqai_agent.user_profile_extractor import UserProfileExtractor
|
|
from pqai_agent.user_profile_extractor import UserProfileExtractor
|
|
from pqai_agent.message import MessageType, Message, MessageChannel
|
|
from pqai_agent.message import MessageType, Message, MessageChannel
|
|
@@ -127,7 +125,7 @@ class AgentService:
|
|
else:
|
|
else:
|
|
logger.warning(f"Unknown message type: {msg.type}")
|
|
logger.warning(f"Unknown message type: {msg.type}")
|
|
|
|
|
|
- def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
|
|
|
|
|
|
+ def get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
|
|
"""获取Agent实例"""
|
|
"""获取Agent实例"""
|
|
agent_key = 'agent_{}_{}'.format(staff_id, user_id)
|
|
agent_key = 'agent_{}_{}'.format(staff_id, user_id)
|
|
if agent_key not in self.agent_registry:
|
|
if agent_key not in self.agent_registry:
|
|
@@ -151,7 +149,7 @@ class AgentService:
|
|
|
|
|
|
def start(self, blocking=False):
|
|
def start(self, blocking=False):
|
|
self.running = True
|
|
self.running = True
|
|
- self.process_thread = threading.Thread(target=service.process_messages)
|
|
|
|
|
|
+ self.process_thread = threading.Thread(target=self.process_messages)
|
|
self.process_thread.start()
|
|
self.process_thread.start()
|
|
self.setup_scheduler()
|
|
self.setup_scheduler()
|
|
# 只有企微场景需要主动发起
|
|
# 只有企微场景需要主动发起
|
|
@@ -217,7 +215,7 @@ class AgentService:
|
|
|
|
|
|
# 获取用户信息和Agent实例
|
|
# 获取用户信息和Agent实例
|
|
user_profile = self.user_manager.get_user_profile(user_id)
|
|
user_profile = self.user_manager.get_user_profile(user_id)
|
|
- agent = self._get_agent_instance(staff_id, user_id)
|
|
|
|
|
|
+ agent = self.get_agent_instance(staff_id, user_id)
|
|
if not agent.is_valid():
|
|
if not agent.is_valid():
|
|
logger.error(f"staff[{staff_id}] user[{user_id}]: agent is invalid")
|
|
logger.error(f"staff[{staff_id}] user[{user_id}]: agent is invalid")
|
|
return
|
|
return
|
|
@@ -248,7 +246,7 @@ class AgentService:
|
|
else:
|
|
else:
|
|
message_type = self.response_type_detector.detect_type(
|
|
message_type = self.response_type_detector.detect_type(
|
|
recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
|
|
recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
|
|
- self._send_response(staff_id, user_id, resp, message_type)
|
|
|
|
|
|
+ self.send_response(staff_id, user_id, resp, message_type)
|
|
else:
|
|
else:
|
|
logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
|
|
logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
|
|
# 当前消息处理成功,commit并持久化agent状态
|
|
# 当前消息处理成功,commit并持久化agent状态
|
|
@@ -257,7 +255,7 @@ class AgentService:
|
|
agent.rollback_state()
|
|
agent.rollback_state()
|
|
raise e
|
|
raise e
|
|
|
|
|
|
- def _send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
|
|
|
|
|
|
+ def send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
|
|
logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
|
|
logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
|
|
current_ts = int(time.time() * 1000)
|
|
current_ts = int(time.time() * 1000)
|
|
user_tags = self.user_relation_manager.get_user_tags(user_id)
|
|
user_tags = self.user_relation_manager.get_user_tags(user_id)
|
|
@@ -302,7 +300,7 @@ class AgentService:
|
|
for staff_user in self.user_relation_manager.list_staff_users():
|
|
for staff_user in self.user_relation_manager.list_staff_users():
|
|
staff_id = staff_user['staff_id']
|
|
staff_id = staff_user['staff_id']
|
|
user_id = staff_user['user_id']
|
|
user_id = staff_user['user_id']
|
|
- agent = self._get_agent_instance(staff_id, user_id)
|
|
|
|
|
|
+ agent = self.get_agent_instance(staff_id, user_id)
|
|
should_initiate = agent.should_initiate_conversation()
|
|
should_initiate = agent.should_initiate_conversation()
|
|
user_tags = self.user_relation_manager.get_user_tags(user_id)
|
|
user_tags = self.user_relation_manager.get_user_tags(user_id)
|
|
|
|
|
|
@@ -326,7 +324,7 @@ class AgentService:
|
|
message_type = MessageType.VOICE
|
|
message_type = MessageType.VOICE
|
|
else:
|
|
else:
|
|
message_type = MessageType.TEXT
|
|
message_type = MessageType.TEXT
|
|
- self._send_response(staff_id, user_id, resp, message_type, skip_check=True)
|
|
|
|
|
|
+ self.send_response(staff_id, user_id, resp, message_type, skip_check=True)
|
|
agent.persist_state()
|
|
agent.persist_state()
|
|
except Exception as e:
|
|
except Exception as e:
|
|
# FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突
|
|
# FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突
|
|
@@ -398,93 +396,4 @@ class AgentService:
|
|
pattern = r'\[?\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]?'
|
|
pattern = r'\[?\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]?'
|
|
response = re.sub(pattern, '', response)
|
|
response = re.sub(pattern, '', response)
|
|
response = response.strip()
|
|
response = response.strip()
|
|
- return response
|
|
|
|
-
|
|
|
|
-if __name__ == "__main__":
|
|
|
|
- config = configs.get()
|
|
|
|
- logging_service.setup_root_logger()
|
|
|
|
- logger.warning("current env: {}".format(configs.get_env()))
|
|
|
|
- scheduler_logger = logging.getLogger('apscheduler')
|
|
|
|
- scheduler_logger.setLevel(logging.WARNING)
|
|
|
|
-
|
|
|
|
- use_aliyun_mq = config['debug_flags']['use_aliyun_mq']
|
|
|
|
-
|
|
|
|
- # 初始化不同队列的后端
|
|
|
|
- if use_aliyun_mq:
|
|
|
|
- receive_queue = AliyunRocketMQQueueBackend(
|
|
|
|
- config['mq']['endpoints'],
|
|
|
|
- config['mq']['instance_id'],
|
|
|
|
- config['mq']['receive_topic'],
|
|
|
|
- has_consumer=True, has_producer=True,
|
|
|
|
- group_id=config['mq']['receive_group'],
|
|
|
|
- topic_type='FIFO'
|
|
|
|
- )
|
|
|
|
- send_queue = AliyunRocketMQQueueBackend(
|
|
|
|
- config['mq']['endpoints'],
|
|
|
|
- config['mq']['instance_id'],
|
|
|
|
- config['mq']['send_topic'],
|
|
|
|
- has_consumer=False, has_producer=True,
|
|
|
|
- topic_type='FIFO'
|
|
|
|
- )
|
|
|
|
- else:
|
|
|
|
- receive_queue = MemoryQueueBackend()
|
|
|
|
- send_queue = MemoryQueueBackend()
|
|
|
|
- human_queue = MemoryQueueBackend()
|
|
|
|
-
|
|
|
|
- # 初始化用户管理服务
|
|
|
|
- # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
|
|
|
|
- user_db_config = config['storage']['user']
|
|
|
|
- staff_db_config = config['storage']['staff']
|
|
|
|
- wecom_db_config = config['storage']['user_relation']
|
|
|
|
- if config['debug_flags'].get('use_local_user_storage', False):
|
|
|
|
- user_manager = LocalUserManager()
|
|
|
|
- user_relation_manager = LocalUserRelationManager()
|
|
|
|
- else:
|
|
|
|
- user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
|
|
|
|
- user_relation_manager = MySQLUserRelationManager(
|
|
|
|
- user_db_config['mysql'], wecom_db_config['mysql'],
|
|
|
|
- config['storage']['staff']['table'],
|
|
|
|
- user_db_config['table'],
|
|
|
|
- wecom_db_config['table']['staff'],
|
|
|
|
- wecom_db_config['table']['relation'],
|
|
|
|
- wecom_db_config['table']['user']
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- # 创建Agent服务
|
|
|
|
- service = AgentService(
|
|
|
|
- receive_backend=receive_queue,
|
|
|
|
- send_backend=send_queue,
|
|
|
|
- human_backend=human_queue,
|
|
|
|
- user_manager=user_manager,
|
|
|
|
- user_relation_manager=user_relation_manager,
|
|
|
|
- chat_service_type=ChatServiceType.COZE_CHAT
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- if not config['debug_flags'].get('console_input', False):
|
|
|
|
- service.start(blocking=True)
|
|
|
|
- sys.exit(0)
|
|
|
|
- else:
|
|
|
|
- service.start()
|
|
|
|
-
|
|
|
|
- message_id = 0
|
|
|
|
- while service.running:
|
|
|
|
- print("Input next message: ")
|
|
|
|
- text = sys.stdin.readline().strip()
|
|
|
|
- if not text:
|
|
|
|
- continue
|
|
|
|
- message_id += 1
|
|
|
|
- sender = '7881301903997433'
|
|
|
|
- receiver = '1688855931724582'
|
|
|
|
- if text in (MessageType.AGGREGATION_TRIGGER.name,
|
|
|
|
- MessageType.HUMAN_INTERVENTION_END.name):
|
|
|
|
- message = Message.build(
|
|
|
|
- MessageType.__members__.get(text),
|
|
|
|
- MessageChannel.CORP_WECHAT,
|
|
|
|
- sender, receiver, None, int(time.time() * 1000))
|
|
|
|
- else:
|
|
|
|
- message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
|
|
|
|
- sender,receiver, text, int(time.time() * 1000)
|
|
|
|
- )
|
|
|
|
- message.msgId = message_id
|
|
|
|
- receive_queue.produce(message)
|
|
|
|
- time.sleep(0.1)
|
|
|
|
|
|
+ return response
|