Browse Source

Update push_service: change to multimodal messages

StrayWarrior 1 month ago
parent
commit
cad5e8cd99
1 changed file with 37 additions and 22 deletions
  1. 37 22
      pqai_agent/push_service.py

+ 37 - 22
pqai_agent/push_service.py

@@ -6,7 +6,7 @@ from datetime import datetime
 from enum import Enum
 from concurrent.futures import ThreadPoolExecutor
 from threading import Thread
-from typing import Optional, Dict
+from typing import Optional, Dict, List
 
 import rocketmq
 from rocketmq import ClientConfiguration, Credentials, SimpleConsumer, FilterExpression
@@ -30,7 +30,7 @@ def generate_task_rmq_message(topic: str, staff_id: str, user_id: str, task_type
         'staff_id': staff_id,
         'user_id': user_id,
         'task_type': task_type.value,
-        # FIXME: 需要支持多模态消息
+        # NOTE:通过传入JSON支持多模态消息
         'content': content or '',
         'timestamp': int(time.time() * 1000),
     }, ensure_ascii=False).encode('utf-8')
@@ -142,26 +142,41 @@ class PushTaskWorkerPool:
                 logger.debug(f"user[{user_id}], do not initiate conversation")
                 self.consumer.ack(msg)
                 return
-            content = task['content']
+            contents: List[Dict] = json.loads(task['content'])
+            if not contents:
+                logger.debug(f"staff[{staff_id}], user[{user_id}]: empty content, do not send")
+                self.consumer.ack(msg)
+                return
             recent_dialogue = agent.dialogue_history[-10:]
             agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
-            # FIXME(zhoutian): 不应该再由agent控制,或者agent和API共享同一配置
-            if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
-                message_type = MessageType.TEXT
-            else:
-                message_type = self.agent_service.response_type_detector.detect_type(
-                    recent_dialogue, content, enable_random=True)
-            response = agent.generate_response(content)
-            if response:
-                current_ts = int(time.time())
-                with self.agent_service.AgentDBSession() as session:
-                    msg_list = [{'type': MessageType.TEXT.value, 'content': response}]
-                    record = AgentPushRecord(staff_id=staff_id, user_id=user_id,
-                                             content=json.dumps(msg_list, ensure_ascii=False),
-                                             timestamp=current_ts)
-                    session.add(record)
-                    session.commit()
-                self.agent_service.send_response(staff_id, user_id, response, message_type, skip_check=True)
+            messages_to_send = []
+            for item in contents:
+                if item["type"] == "text":
+                    if staff_id not in agent_voice_whitelist:
+                        message_type = MessageType.TEXT
+                    else:
+                        message_type = self.agent_service.response_type_detector.detect_type(
+                            recent_dialogue, item["content"], enable_random=True)
+                    response = agent.generate_response(item["content"])
+                    if response:
+                        messages_to_send.append({'type': message_type, 'content': response})
+                else:
+                    message_type = MessageType.from_str(item["type"])
+                    response = agent.generate_multimodal_response(item)
+                    if response:
+                        item["type"] = message_type
+                        messages_to_send.append(item)
+            current_ts = int(time.time())
+            with self.agent_service.AgentDBSession() as session:
+                msg_list = []
+                record = AgentPushRecord(staff_id=staff_id, user_id=user_id,
+                                         content=json.dumps(msg_list, ensure_ascii=False),
+                                         timestamp=current_ts)
+                session.add(record)
+                session.commit()
+            if messages_to_send:
+                for response in messages_to_send:
+                    self.agent_service.send_multimodal_response(staff_id, user_id, response, skip_check=True)
                 agent.update_last_active_interaction_time(current_ts)
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: generate empty response")
@@ -184,8 +199,8 @@ class PushTaskWorkerPool:
                 )
             )
             if message_to_user:
-                rmq_message = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.SEND, message_to_user)
-                logger.debug(f"send message: {rmq_message.body.decode('utf-8')}")
+                rmq_message = generate_task_rmq_message(
+                    self.rmq_topic, staff_id, user_id, TaskType.SEND, json.dumps(message_to_user))
                 self.producer.send(rmq_message)
             else:
                 logger.info(f"staff[{staff_id}], user[{user_id}]: no push message generated")