Explorar el Código

Update agent_service: add send_multimodal_response

StrayWarrior hace 2 semanas
padre
commit
87637f43c8
Se han modificado 1 ficheros con 23 adiciones y 6 borrados
  1. 23 6
      pqai_agent/agent_service.py

+ 23 - 6
pqai_agent/agent_service.py

@@ -311,24 +311,41 @@ class AgentService:
             agent.rollback_state()
             raise e
 
-    def send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
-        logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
-        current_ts = int(time.time() * 1000)
+    def can_send_to_user(self, staff_id, user_id) -> bool:
         user_tags = self.user_relation_manager.get_user_tags(user_id)
         white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags", []))
         hit_white_list_tags = len(set(user_tags).intersection(white_list_tags)) > 0
-        # FIXME(zhoutian)
-        # 测试期间临时逻辑,只发送特定的账号或特定用户
         staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs", []))
-        if not (staff_id in staff_white_lists or hit_white_list_tags or skip_check):
+        if not (staff_id in staff_white_lists or hit_white_list_tags):
             logger.warning(f"staff[{staff_id}] user[{user_id}]: skip reply")
+            return False
+        return True
+
+    def send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
+        logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
+        if not skip_check and not self.can_send_to_user(staff_id, user_id):
             return
+        current_ts = int(time.time() * 1000)
         self.send_rate_limiter.wait_for_sending(staff_id, response)
         self.send_queue.produce(
             MqMessage.build(message_type, MessageChannel.CORP_WECHAT,
                             staff_id, user_id, response, current_ts)
         )
 
+    def send_multimodal_response(self, staff_id, user_id, response: Dict, skip_check=False):
+        message_type = response["type"]
+        if message_type not in (MessageType.TEXT, MessageType.IMAGE_QW, MessageType.VOICE):
+            logger.error(f"staff[{staff_id}] user[{user_id}]: unsupported message type {message_type}")
+            return
+        if not skip_check and not self.can_send_to_user(staff_id, user_id):
+            return
+        current_ts = int(time.time() * 1000)
+        self.send_rate_limiter.wait_for_sending(staff_id, response)
+        self.send_queue.produce(
+            MqMessage.build(message_type, MessageChannel.CORP_WECHAT,
+                            staff_id, user_id, response["content"], current_ts)
+        )
+
     def _route_to_human_intervention(self, user_id: str, origin_message: MqMessage):
         """路由到人工干预"""
         self.human_queue.produce(MqMessage.build(