浏览代码

Move main function outside of agent_service

StrayWarrior 1 月之前
父节点
当前提交
9c73e5c5ec
共有 2 个文件被更改,包括 111 次插入100 次删除
  1. 9 100
      pqai_agent/agent_service.py
  2. 102 0
      pqai_agent_server/agent_server.py

+ 9 - 100
pqai_agent/agent_service.py

@@ -16,7 +16,6 @@ import apscheduler.triggers.cron
 from apscheduler.schedulers.background import BackgroundScheduler
 
 from pqai_agent import configs
-from pqai_agent import logging_service
 from pqai_agent.configs import apollo_config
 from pqai_agent.logging_service import logger
 from pqai_agent import chat_service
@@ -24,8 +23,7 @@ from pqai_agent.chat_service import CozeChat, ChatServiceType
 from pqai_agent.dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
 from pqai_agent.rate_limiter import MessageSenderRateLimiter
 from pqai_agent.response_type_detector import ResponseTypeDetector
-from pqai_agent.user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager, \
-    LocalUserRelationManager
+from pqai_agent.user_manager import UserManager, UserRelationManager
 from pqai_agent.message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
 from pqai_agent.user_profile_extractor import UserProfileExtractor
 from pqai_agent.message import MessageType, Message, MessageChannel
@@ -127,7 +125,7 @@ class AgentService:
         else:
             logger.warning(f"Unknown message type: {msg.type}")
 
-    def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
+    def get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
         """获取Agent实例"""
         agent_key = 'agent_{}_{}'.format(staff_id, user_id)
         if agent_key not in self.agent_registry:
@@ -151,7 +149,7 @@ class AgentService:
 
     def start(self, blocking=False):
         self.running = True
-        self.process_thread = threading.Thread(target=service.process_messages)
+        self.process_thread = threading.Thread(target=self.process_messages)
         self.process_thread.start()
         self.setup_scheduler()
         # 只有企微场景需要主动发起
@@ -217,7 +215,7 @@ class AgentService:
 
         # 获取用户信息和Agent实例
         user_profile = self.user_manager.get_user_profile(user_id)
-        agent = self._get_agent_instance(staff_id, user_id)
+        agent = self.get_agent_instance(staff_id, user_id)
         if not agent.is_valid():
             logger.error(f"staff[{staff_id}] user[{user_id}]: agent is invalid")
             return
@@ -248,7 +246,7 @@ class AgentService:
                     else:
                         message_type = self.response_type_detector.detect_type(
                             recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
-                    self._send_response(staff_id, user_id, resp, message_type)
+                    self.send_response(staff_id, user_id, resp, message_type)
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
             # 当前消息处理成功,commit并持久化agent状态
@@ -257,7 +255,7 @@ class AgentService:
             agent.rollback_state()
             raise e
 
-    def _send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
+    def send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
         logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
         current_ts = int(time.time() * 1000)
         user_tags = self.user_relation_manager.get_user_tags(user_id)
@@ -302,7 +300,7 @@ class AgentService:
         for staff_user in self.user_relation_manager.list_staff_users():
             staff_id = staff_user['staff_id']
             user_id = staff_user['user_id']
-            agent = self._get_agent_instance(staff_id, user_id)
+            agent = self.get_agent_instance(staff_id, user_id)
             should_initiate = agent.should_initiate_conversation()
             user_tags = self.user_relation_manager.get_user_tags(user_id)
 
@@ -326,7 +324,7 @@ class AgentService:
                             message_type = MessageType.VOICE
                         else:
                             message_type = MessageType.TEXT
-                        self._send_response(staff_id, user_id, resp, message_type, skip_check=True)
+                        self.send_response(staff_id, user_id, resp, message_type, skip_check=True)
                     agent.persist_state()
                 except Exception as e:
                     # FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突
@@ -398,93 +396,4 @@ class AgentService:
         pattern = r'\[?\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]?'
         response = re.sub(pattern, '', response)
         response = response.strip()
-        return response
-
-if __name__ == "__main__":
-    config = configs.get()
-    logging_service.setup_root_logger()
-    logger.warning("current env: {}".format(configs.get_env()))
-    scheduler_logger = logging.getLogger('apscheduler')
-    scheduler_logger.setLevel(logging.WARNING)
-
-    use_aliyun_mq = config['debug_flags']['use_aliyun_mq']
-
-    # 初始化不同队列的后端
-    if use_aliyun_mq:
-        receive_queue = AliyunRocketMQQueueBackend(
-            config['mq']['endpoints'],
-            config['mq']['instance_id'],
-            config['mq']['receive_topic'],
-            has_consumer=True, has_producer=True,
-            group_id=config['mq']['receive_group'],
-            topic_type='FIFO'
-        )
-        send_queue = AliyunRocketMQQueueBackend(
-            config['mq']['endpoints'],
-            config['mq']['instance_id'],
-            config['mq']['send_topic'],
-            has_consumer=False, has_producer=True,
-            topic_type='FIFO'
-        )
-    else:
-        receive_queue = MemoryQueueBackend()
-        send_queue = MemoryQueueBackend()
-    human_queue = MemoryQueueBackend()
-
-    # 初始化用户管理服务
-    # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
-    user_db_config = config['storage']['user']
-    staff_db_config = config['storage']['staff']
-    wecom_db_config = config['storage']['user_relation']
-    if config['debug_flags'].get('use_local_user_storage', False):
-        user_manager = LocalUserManager()
-        user_relation_manager = LocalUserRelationManager()
-    else:
-        user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
-        user_relation_manager = MySQLUserRelationManager(
-            user_db_config['mysql'], wecom_db_config['mysql'],
-            config['storage']['staff']['table'],
-            user_db_config['table'],
-            wecom_db_config['table']['staff'],
-            wecom_db_config['table']['relation'],
-            wecom_db_config['table']['user']
-        )
-
-    # 创建Agent服务
-    service = AgentService(
-        receive_backend=receive_queue,
-        send_backend=send_queue,
-        human_backend=human_queue,
-        user_manager=user_manager,
-        user_relation_manager=user_relation_manager,
-        chat_service_type=ChatServiceType.COZE_CHAT
-    )
-
-    if not config['debug_flags'].get('console_input', False):
-        service.start(blocking=True)
-        sys.exit(0)
-    else:
-        service.start()
-
-    message_id = 0
-    while service.running:
-        print("Input next message: ")
-        text = sys.stdin.readline().strip()
-        if not text:
-            continue
-        message_id += 1
-        sender = '7881301903997433'
-        receiver = '1688855931724582'
-        if text in (MessageType.AGGREGATION_TRIGGER.name,
-                    MessageType.HUMAN_INTERVENTION_END.name):
-            message = Message.build(
-                MessageType.__members__.get(text),
-                MessageChannel.CORP_WECHAT,
-                sender, receiver, None, int(time.time() * 1000))
-        else:
-            message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
-                sender,receiver, text, int(time.time() * 1000)
-            )
-        message.msgId = message_id
-        receive_queue.produce(message)
-        time.sleep(0.1)
+        return response

+ 102 - 0
pqai_agent_server/agent_server.py

@@ -0,0 +1,102 @@
+import logging
+import sys
+import time
+
+from pqai_agent import configs, logging_service
+from pqai_agent.agent_service import AgentService
+from pqai_agent.chat_service import ChatServiceType
+from pqai_agent.logging_service import logger
+from pqai_agent.message import MessageType, Message, MessageChannel
+from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend, MemoryQueueBackend
+from pqai_agent.push_service import PushTaskWorkerPool, PushScanThread
+from pqai_agent.user_manager import LocalUserManager, LocalUserRelationManager, MySQLUserManager, \
+    MySQLUserRelationManager
+
+if __name__ == "__main__":
+    config = configs.get()
+    logging_service.setup_root_logger()
+    logger.warning("current env: {}".format(configs.get_env()))
+    scheduler_logger = logging.getLogger('apscheduler')
+    scheduler_logger.setLevel(logging.WARNING)
+
+    use_aliyun_mq = config['debug_flags']['use_aliyun_mq']
+
+    # 初始化不同队列的后端
+    if use_aliyun_mq:
+        receive_queue = AliyunRocketMQQueueBackend(
+            config['mq']['endpoints'],
+            config['mq']['instance_id'],
+            config['mq']['receive_topic'],
+            has_consumer=True, has_producer=True,
+            group_id=config['mq']['receive_group'],
+            topic_type='FIFO'
+        )
+        send_queue = AliyunRocketMQQueueBackend(
+            config['mq']['endpoints'],
+            config['mq']['instance_id'],
+            config['mq']['send_topic'],
+            has_consumer=False, has_producer=True,
+            topic_type='FIFO'
+        )
+    else:
+        receive_queue = MemoryQueueBackend()
+        send_queue = MemoryQueueBackend()
+    human_queue = MemoryQueueBackend()
+
+    # 初始化用户管理服务
+    # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
+    user_db_config = config['storage']['user']
+    staff_db_config = config['storage']['staff']
+    wecom_db_config = config['storage']['user_relation']
+    if config['debug_flags'].get('use_local_user_storage', False):
+        user_manager = LocalUserManager()
+        user_relation_manager = LocalUserRelationManager()
+    else:
+        user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
+        user_relation_manager = MySQLUserRelationManager(
+            user_db_config['mysql'], wecom_db_config['mysql'],
+            config['storage']['staff']['table'],
+            user_db_config['table'],
+            wecom_db_config['table']['staff'],
+            wecom_db_config['table']['relation'],
+            wecom_db_config['table']['user']
+        )
+
+    # 创建Agent服务
+    service = AgentService(
+        receive_backend=receive_queue,
+        send_backend=send_queue,
+        human_backend=human_queue,
+        user_manager=user_manager,
+        user_relation_manager=user_relation_manager,
+        chat_service_type=ChatServiceType.COZE_CHAT
+    )
+
+    if not config['debug_flags'].get('console_input', False):
+        service.start(blocking=True)
+        sys.exit(0)
+    else:
+        service.start()
+
+    message_id = 0
+    while service.running:
+        print("Input next message: ")
+        text = sys.stdin.readline().strip()
+        if not text:
+            continue
+        message_id += 1
+        sender = '7881301903997433'
+        receiver = '1688855931724582'
+        if text in (MessageType.AGGREGATION_TRIGGER.name,
+                    MessageType.HUMAN_INTERVENTION_END.name):
+            message = Message.build(
+                MessageType.__members__.get(text),
+                MessageChannel.CORP_WECHAT,
+                sender, receiver, None, int(time.time() * 1000))
+        else:
+            message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+                                    sender,receiver, text, int(time.time() * 1000)
+                                    )
+        message.msgId = message_id
+        receive_queue.produce(message)
+        time.sleep(0.1)