瀏覽代碼

Update agent_service: multimodal response support

StrayWarrior 1 周之前
父節點
當前提交
f0f1643f4a
共有 1 個文件被更改,包括 28 次插入和 12 次刪除
  1. 28 12
      pqai_agent/agent_service.py

+ 28 - 12
pqai_agent/agent_service.py

@@ -295,15 +295,7 @@ class AgentService:
                 # 先更新用户画像再处理回复
                 # 先更新用户画像再处理回复
                 self._update_user_profile(user_id, user_profile, agent.dialogue_history[-10:])
                 self._update_user_profile(user_id, user_profile, agent.dialogue_history[-10:])
                 resp = self.get_chat_response(agent, message_text)
                 resp = self.get_chat_response(agent, message_text)
-                if resp:
-                    recent_dialogue = agent.dialogue_history[-10:]
-                    agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
-                    if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
-                        message_type = MessageType.TEXT
-                    else:
-                        message_type = self.response_type_detector.detect_type(
-                            recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
-                    self.send_response(staff_id, user_id, resp, message_type)
+                self.send_responses(agent, resp)
             else:
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
             # 当前消息处理成功,commit并持久化agent状态
             # 当前消息处理成功,commit并持久化agent状态
@@ -312,6 +304,25 @@ class AgentService:
             agent.rollback_state()
             agent.rollback_state()
             raise e
             raise e
 
 
+    def send_responses(self, agent: DialogueManager, contents: List[Dict]):
+        staff_id = agent.staff_id
+        user_id = agent.user_id
+        recent_dialogue = agent.dialogue_history[-10:]
+        agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
+        for item in contents:
+            if item["type"] == MessageType.TEXT:
+                if staff_id in agent_voice_whitelist:
+                    message_type = self.response_type_detector.detect_type(
+                        recent_dialogue, item["content"], enable_random=True)
+                    item["type"] = message_type
+        if contents:
+            current_ts = int(time.time())
+            for response in contents:
+                self.send_multimodal_response(staff_id, user_id, response, skip_check=True)
+            agent.update_last_active_interaction_time(current_ts)
+        else:
+            logger.debug(f"staff[{staff_id}], user[{user_id}]: no messages to send")
+
     def can_send_to_user(self, staff_id, user_id) -> bool:
     def can_send_to_user(self, staff_id, user_id) -> bool:
         user_tags = self.user_relation_manager.get_user_tags(user_id)
         user_tags = self.user_relation_manager.get_user_tags(user_id)
         white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags", []))
         white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags", []))
@@ -335,6 +346,7 @@ class AgentService:
 
 
     def send_multimodal_response(self, staff_id, user_id, response: Dict, skip_check=False):
     def send_multimodal_response(self, staff_id, user_id, response: Dict, skip_check=False):
         message_type = response["type"]
         message_type = response["type"]
+        logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
         if message_type not in (MessageType.TEXT, MessageType.IMAGE_QW, MessageType.VOICE):
         if message_type not in (MessageType.TEXT, MessageType.IMAGE_QW, MessageType.VOICE):
             logger.error(f"staff[{staff_id}] user[{user_id}]: unsupported message type {message_type}")
             logger.error(f"staff[{staff_id}] user[{user_id}]: unsupported message type {message_type}")
             return
             return
@@ -410,14 +422,15 @@ class AgentService:
         # 问题在于,如果每次创建出新的PushTaskWorkerPool,在上次任务有未处理完的消息即退出时,会有未处理的消息堆积
         # 问题在于,如果每次创建出新的PushTaskWorkerPool,在上次任务有未处理完的消息即退出时,会有未处理的消息堆积
         push_task_worker_pool.wait_to_finish()
         push_task_worker_pool.wait_to_finish()
 
 
-    def get_chat_response(self, agent: DialogueManager, user_message: Optional[str]) -> Union[str, List[Dict]]:
+    def get_chat_response(self, agent: DialogueManager, user_message: Optional[str]) -> List[Dict]:
         chat_agent_ver = self.config.get('system', {}).get('chat_agent_version', 1)
         chat_agent_ver = self.config.get('system', {}).get('chat_agent_version', 1)
         if chat_agent_ver == 2:
         if chat_agent_ver == 2:
             return self._get_chat_response_v2(agent)
             return self._get_chat_response_v2(agent)
         else:
         else:
-            return self._get_chat_response_v1(agent, user_message)
+            text_resp = self._get_chat_response_v1(agent, user_message)
+            return [{"type": MessageType.TEXT, "content": text_resp}] if text_resp else []
 
 
-    def _get_chat_response_v1(self, agent: DialogueManager, user_message: Optional[str]) -> str:
+    def _get_chat_response_v1(self, agent: DialogueManager, user_message: Optional[str]) -> Optional[str]:
         chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
         chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
         config_for_logging = chat_config.copy()
         config_for_logging = chat_config.copy()
         config_for_logging['messages'] = config_for_logging['messages'][-20:]
         config_for_logging['messages'] = config_for_logging['messages'][-20:]
@@ -444,6 +457,9 @@ class AgentService:
         for chat_response in chat_responses:
         for chat_response in chat_responses:
             if response := main_agent.generate_multimodal_response(chat_response):
             if response := main_agent.generate_multimodal_response(chat_response):
                 final_responses.append(response)
                 final_responses.append(response)
+            else:
+                # 存在非法/结束消息,清空待发消息
+                final_responses.clear()
         return final_responses
         return final_responses
 
 
     def _call_chat_api(self, chat_config: Dict, chat_service_type: ChatServiceType) -> str:
     def _call_chat_api(self, chat_config: Dict, chat_service_type: ChatServiceType) -> str: