Sfoglia il codice sorgente

Update dialogue_manager and agent_service: for push experiment

StrayWarrior 1 settimana fa
parent
commit
2181c14618
2 ha cambiato i file con 29 aggiunte e 12 eliminazioni
  1. 16 10
      agent_service.py
  2. 13 2
      dialogue_manager.py

+ 16 - 10
agent_service.py

@@ -196,7 +196,7 @@ class AgentService:
             agent.rollback_state()
             raise e
 
-    def _send_response(self, staff_id, user_id, response, message_type: MessageType):
+    def _send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
         logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
         current_ts = int(time.time() * 1000)
         user_tags = self.user_relation_manager.get_user_tags(user_id)
@@ -205,7 +205,7 @@ class AgentService:
         # FIXME(zhoutian)
         # 测试期间临时逻辑,只发送特定的账号或特定用户
         staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs"))
-        if not (staff_id in staff_white_lists or hit_white_list_tags):
+        if not (staff_id in staff_white_lists or hit_white_list_tags or skip_check):
             logger.warning(f"staff[{staff_id}] user[{user_id}]: skip reply")
             return None
         self.send_queue.produce(
@@ -226,6 +226,13 @@ class AgentService:
 
     def _check_initiative_conversations(self):
         logger.info("start to check initiative conversations")
+        white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags'))
+        first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
+        # 合并白名单,减少配置成本
+        white_list_tags.update(first_initiate_tags)
+        voice_tags = set(apollo_config.get_json_value('agent_initiate_by_voice_tags'))
+
+
         """定时检查主动发起对话"""
         for staff_user in self.user_relation_manager.list_staff_users():
             staff_id = staff_user['staff_id']
@@ -233,9 +240,8 @@ class AgentService:
             agent = self._get_agent_instance(staff_id, user_id)
             should_initiate = agent.should_initiate_conversation()
             user_tags = self.user_relation_manager.get_user_tags(user_id)
-            white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags'))
-            voice_tags = set(apollo_config.get_json_value('agent_initiate_by_voice_tags'))
-            if configs.get_env() != 'dev' and not set(user_tags).intersection(white_list_tags):
+
+            if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
                 should_initiate = False
 
             if should_initiate:
@@ -243,9 +249,9 @@ class AgentService:
                 # FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突 需要并入事件驱动框架
                 agent.do_state_change(DialogueState.GREETING)
                 try:
-                    if agent.previous_state == DialogueState.INITIALIZED:
-                        # 完全无交互历史的用户才使用此策略
-                        # 问题:agent状态更新后即无法再次发起此策略
+                    if agent.previous_state == DialogueState.INITIALIZED or first_initiate_tags.intersection(user_tags):
+                        # 完全无交互历史的用户才使用此策略,但新用户接入即会产生“我已添加了你”的消息将Agent初始化
+                        # 因此存量用户无法使用该状态做实验
                         # TODO:增加基于对话历史的判断、策略去重;如果对话间隔过长需要使用长期记忆检索;在无长期记忆时,可采用用户添加时间来判断
                         resp = self._generate_active_greeting_message(agent, user_tags)
                     else:
@@ -255,7 +261,7 @@ class AgentService:
                             message_type = MessageType.VOICE
                         else:
                             message_type = MessageType.TEXT
-                        self._send_response(staff_id, user_id, resp, message_type)
+                        self._send_response(staff_id, user_id, resp, message_type, skip_check=True)
                         if self.limit_initiative_conversation_rate:
                             time.sleep(random.randint(10,20))
                     agent.persist_state()
@@ -267,7 +273,7 @@ class AgentService:
                 logger.debug("user: {}, do not initiate conversation".format(user_id))
 
     def _generate_active_greeting_message(self, agent: DialogueManager, user_tags: List[str]=None):
-        chat_config = agent.build_active_greeting_config()
+        chat_config = agent.build_active_greeting_config(user_tags)
         chat_response = self._call_chat_api(chat_config, ChatServiceType.OPENAI_COMPATIBLE)
         chat_response = self.sanitize_response(chat_response)
         if response := agent.generate_response(chat_response):

+ 13 - 2
dialogue_manager.py

@@ -617,8 +617,8 @@ class DialogueManager:
                 last_message_role, aggregated_message, ChatServiceType.COZE_CHAT))
         return messages
 
-    def build_active_greeting_config(self):
-        # FIXME: 这里的抽象不好
+    def build_active_greeting_config(self, user_tags: List[str]):
+        # FIXME: 这里的抽象不好,短期支持人为配置实验
         chat_config = {'user_id': self.user_id}
         prompt_context = self.get_prompt_context(None)
 
@@ -631,7 +631,18 @@ class DialogueManager:
             prompt_templates.GREETING_WITH_NAME_POETRY,
             prompt_templates.GREETING_WITH_AVATAR_STORY
         ]
+        # 默认随机选择
         selected_prompt = greeting_prompts[random.randint(0, len(greeting_prompts) - 1)]
+        # 实验配置
+        tag_to_greeting_map = {
+            '04W4-AA-1': prompt_templates.GREETING_WITH_NAME_POETRY,
+            '04W4-AA-2': prompt_templates.GREETING_WITH_AVATAR_STORY,
+            '04W4-AA-3': prompt_templates.GREETING_WITH_AVATAR_STORY,
+            '04W4-AA-4': prompt_templates.GREETING_WITH_IMAGE_GAME,
+        }
+        for tag in user_tags:
+            if tag in tag_to_greeting_map:
+                selected_prompt = tag_to_greeting_map[tag]
         prompt = selected_prompt.format(**prompt_context)
         user_message = {'role': 'user', 'content': prompt}
         messages = [system_message, user_message]