Browse Source

Update agent_service: use UserRelationManager

StrayWarrior 2 weeks ago
parent
commit
972862700f
2 changed files with 25 additions and 10 deletions
  1. 23 9
      agent_service.py
  2. 2 1
      unit_test.py

+ 23 - 9
agent_service.py

@@ -13,11 +13,10 @@ from apscheduler.schedulers.background import BackgroundScheduler
 
 import chat_service
 import configs
-import global_flags
 import logging_service
 from chat_service import CozeChat, ChatServiceType
 from dialogue_manager import DialogueManager, DialogueState
-from user_manager import UserManager, LocalUserManager, MySQLUserManager
+from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager
 from openai import OpenAI
 from message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
 from user_profile_extractor import UserProfileExtractor
@@ -33,6 +32,7 @@ class AgentService:
         send_backend: MessageQueueBackend,
         human_backend: MessageQueueBackend,
         user_manager: UserManager,
+        user_relation_manager: UserRelationManager,
         chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
     ):
         self.receive_queue = receive_backend
@@ -41,6 +41,7 @@ class AgentService:
 
         # 核心服务模块
         self.user_manager = user_manager
+        self.user_relation_manager = user_relation_manager
         self.user_profile_extractor = UserProfileExtractor()
         self.agent_registry: Dict[str, DialogueManager] = {}
 
@@ -149,9 +150,10 @@ class AgentService:
 
     def _check_initiative_conversations(self):
         """定时检查主动发起对话"""
-        for user_id in self.user_manager.list_all_users():
-            #FIXME(zhoutian): 需要企微账号与用户关系
-            agent = self._get_agent_instance('staff_id_0', user_id)
+        for staff_user in self.user_relation_manager.list_staff_users():
+            staff_id = staff_user['staff_id']
+            user_id = staff_user['user_id']
+            agent = self._get_agent_instance(staff_id, user_id)
             should_initiate = agent.should_initiate_conversation()
 
             if should_initiate:
@@ -161,7 +163,7 @@ class AgentService:
                 logging.debug("user: {}, do not initiate conversation".format(user_id))
 
     def _get_chat_response(self, user_id: str, agent: DialogueManager,
-                           user_message: str):
+                           user_message: Optional[str]):
         """处理LLM响应"""
         chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
         logging.debug(chat_config)
@@ -223,11 +225,22 @@ if __name__ == "__main__":
     human_queue = MemoryQueueBackend()
 
     # 初始化用户管理服务
+    # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
+    user_db_config = config['storage']['user']
     if config['debug_flags'].get('use_local_user_manager', False):
         user_manager = LocalUserManager()
     else:
-        db_config = config['storage']['user']
-        user_manager = MySQLUserManager(db_config['mysql'], db_config['table'])
+        user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'])
+
+    wecom_db_config = config['storage']['user_relation']
+    user_relation_manager = MySQLUserRelationManager(
+        user_db_config['mysql'], wecom_db_config['mysql'],
+        config['storage']['staff']['table'],
+        user_db_config['table'],
+        wecom_db_config['table']['staff'],
+        wecom_db_config['table']['relation'],
+        wecom_db_config['table']['user']
+    )
 
     # 创建Agent服务
     service = AgentService(
@@ -235,10 +248,11 @@ if __name__ == "__main__":
         send_backend=send_queue,
         human_backend=human_queue,
         user_manager=user_manager,
+        user_relation_manager=user_relation_manager,
         chat_service_type=ChatServiceType.COZE_CHAT
     )
     # 只有企微场景需要主动发起
-    # service.setup_initiative_conversations({'second': '5,35'})
+    service.setup_initiative_conversations({'second': '5,35'})
 
     process_thread = threading.Thread(target=service.process_messages)
     process_thread.start()

+ 2 - 1
unit_test.py

@@ -37,7 +37,8 @@ def test_env():
         receive_backend=receive_queue,
         send_backend=send_queue,
         human_backend=human_queue,
-        user_manager=user_manager
+        user_manager=user_manager,
+        user_relation_manager=None
     )
     service.user_profile_extractor.extract_profile_info = Mock(return_value=None)