瀏覽代碼

Fix response type detection

StrayWarrior 1 周之前
父節點
當前提交
a8cf7255d6
共有 2 個文件被更改,包括 7 次插入5 次删除
  1. 4 3
      pqai_agent/agent_service.py
  2. 3 2
      pqai_agent/push_service.py

+ 4 - 3
pqai_agent/agent_service.py

@@ -309,14 +309,15 @@ class AgentService:
         user_id = agent.user_id
         recent_dialogue = agent.dialogue_history[-10:]
         agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
+        current_ts = int(time.time())
         for item in contents:
+            item["timestamp"] = current_ts * 1000
             if item["type"] == MessageType.TEXT:
-                if staff_id in agent_voice_whitelist:
+                if staff_id in agent_voice_whitelist or True:
                     message_type = self.response_type_detector.detect_type(
-                        recent_dialogue, item["content"], enable_random=True)
+                        recent_dialogue, item, enable_random=True)
                     item["type"] = message_type
         if contents:
-            current_ts = int(time.time())
             for response in contents:
                 self.send_multimodal_response(staff_id, user_id, response, skip_check=True)
             agent.update_last_active_interaction_time(current_ts)

+ 3 - 2
pqai_agent/push_service.py

@@ -150,13 +150,15 @@ class PushTaskWorkerPool:
             recent_dialogue = agent.dialogue_history[-10:]
             agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
             messages_to_send = []
+            current_ts = int(time.time())
             for item in contents:
+                item["timestamp"] = current_ts * 1000
                 if item["type"] == "text":
                     if staff_id not in agent_voice_whitelist:
                         message_type = MessageType.TEXT
                     else:
                         message_type = self.agent_service.response_type_detector.detect_type(
-                            recent_dialogue, item["content"], enable_random=True)
+                            recent_dialogue, item, enable_random=True)
                     response = agent.generate_response(item["content"])
                     if response:
                         messages_to_send.append({'type': message_type, 'content': response})
@@ -166,7 +168,6 @@ class PushTaskWorkerPool:
                     if response:
                         item["type"] = message_type
                         messages_to_send.append(item)
-            current_ts = int(time.time())
             with self.agent_service.AgentDBSession() as session:
                 msg_list = [{"type": msg["type"].value, "content": msg["content"]} for msg in messages_to_send]
                 record = AgentPushRecord(staff_id=staff_id, user_id=user_id,