Prechádzať zdrojové kódy

Update agent_service: skip whitelist for initiative conversation in dev environment

StrayWarrior 2 týždňov pred
rodič
commit
74fc7de714
1 zmenil súbory, kde vykonal 16 pridanie a 12 odobranie
  1. 16 12
      agent_service.py

+ 16 - 12
agent_service.py

@@ -22,7 +22,8 @@ from logging_service import logger
 from chat_service import CozeChat, ChatServiceType
 from dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
 from response_type_detector import ResponseTypeDetector
-from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager
+from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager, \
+    LocalUserRelationManager
 from openai import OpenAI
 from message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
 from user_profile_extractor import UserProfileExtractor
@@ -52,7 +53,8 @@ class AgentService:
         self.response_type_detector = ResponseTypeDetector()
         self.agent_registry: Dict[str, DialogueManager] = {}
 
-        chat_config = configs.get()['chat_api']['openai_compatible']
+        self.config = configs.get()
+        chat_config = self.config['chat_api']['openai_compatible']
         self.text_model_name = chat_config['text_model']
         self.multimodal_model_name = chat_config['multimodal_model']
         self.text_model_client = chat_service.OpenAICompatible.create_client(self.text_model_name)
@@ -232,7 +234,7 @@ class AgentService:
             should_initiate = agent.should_initiate_conversation()
             user_tags = self.user_relation_manager.get_user_tags(user_id)
             white_list_tags = apollo_config.get_json_value('agent_initiate_whitelist_tags')
-            if not set(user_tags).intersection(white_list_tags):
+            if configs.get_env() != 'dev' and not set(user_tags).intersection(white_list_tags):
                 should_initiate = False
 
             if should_initiate:
@@ -324,20 +326,22 @@ if __name__ == "__main__":
     # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
     user_db_config = config['storage']['user']
     staff_db_config = config['storage']['staff']
+    wecom_db_config = config['storage']['user_relation']
     if config['debug_flags'].get('use_local_user_storage', False):
         user_manager = LocalUserManager()
+        user_relation_manager = LocalUserRelationManager()
     else:
         user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
+        user_relation_manager = MySQLUserRelationManager(
+            user_db_config['mysql'], wecom_db_config['mysql'],
+            config['storage']['staff']['table'],
+            user_db_config['table'],
+            wecom_db_config['table']['staff'],
+            wecom_db_config['table']['relation'],
+            wecom_db_config['table']['user']
+        )
+
 
-    wecom_db_config = config['storage']['user_relation']
-    user_relation_manager = MySQLUserRelationManager(
-        user_db_config['mysql'], wecom_db_config['mysql'],
-        config['storage']['staff']['table'],
-        user_db_config['table'],
-        wecom_db_config['table']['staff'],
-        wecom_db_config['table']['relation'],
-        wecom_db_config['table']['user']
-    )
 
     # 创建Agent服务
     service = AgentService(