浏览代码

Add unit test for dialogue state change

StrayWarrior 3 天之前
父节点
当前提交
cd69e87bea
共有 2 个文件被更改,包括 35 次插入4 次删除
  1. 4 3
      dialogue_manager.py
  2. 31 1
      unit_test.py

+ 4 - 3
dialogue_manager.py

@@ -97,6 +97,7 @@ class DialogueManager:
         self.dialogue_history = []
         self.dialogue_history = []
         self.user_profile = self.user_manager.get_user_profile(user_id)
         self.user_profile = self.user_manager.get_user_profile(user_id)
         self.staff_profile = self.user_manager.get_staff_profile(staff_id)
         self.staff_profile = self.user_manager.get_staff_profile(staff_id)
+        # FIXME(zhoutian): last_interaction_time也需要回滚
         self.last_interaction_time = 0
         self.last_interaction_time = 0
         self.consecutive_clarifications = 0
         self.consecutive_clarifications = 0
         self.complex_request_counter = 0
         self.complex_request_counter = 0
@@ -156,7 +157,7 @@ class DialogueManager:
     def do_state_change(self, state: DialogueState):
     def do_state_change(self, state: DialogueState):
         self.state_backup = (self.current_state, self.previous_state)
         self.state_backup = (self.current_state, self.previous_state)
         if self.current_state == DialogueState.MESSAGE_AGGREGATING:
         if self.current_state == DialogueState.MESSAGE_AGGREGATING:
-            # 不需要更新previous_state
+            # MESSAGE_AGGREGATING不能成为previous_state,仅使用state_backup做回退
             self.current_state = state
             self.current_state = state
         else:
         else:
             self.previous_state = self.current_state
             self.previous_state = self.current_state
@@ -286,7 +287,7 @@ class DialogueManager:
             event = {
             event = {
                 "timestamp": int(time.time() * 1000),
                 "timestamp": int(time.time() * 1000),
                 "reason": reason,
                 "reason": reason,
-                "dialogue_context": self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id, 60)
+                "dialogue_context": self.dialogue_history[-10:]
             }
             }
 
 
             # 更新用户资料中的人工介入历史
             # 更新用户资料中的人工介入历史
@@ -310,7 +311,7 @@ class DialogueManager:
         """
         """
 
 
         # 添加最近的对话记录
         # 添加最近的对话记录
-        recent_dialogues = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id, 10)
+        recent_dialogues = self.dialogue_history[-10:]
         for dialogue in recent_dialogues:
         for dialogue in recent_dialogues:
             alert_message += f"\n{dialogue['role']}: {dialogue['content']}"
             alert_message += f"\n{dialogue['role']}: {dialogue['content']}"
 
 

+ 31 - 1
unit_test.py

@@ -7,6 +7,7 @@ from datetime import datetime, timedelta
 from typing import Dict, Optional, Tuple, Any
 from typing import Dict, Optional, Tuple, Any
 from unittest.mock import Mock, MagicMock
 from unittest.mock import Mock, MagicMock
 from agent_service import AgentService, MemoryQueueBackend
 from agent_service import AgentService, MemoryQueueBackend
+from dialogue_manager import DialogueState, TimeContext
 from message import MessageType, Message, MessageChannel
 from message import MessageType, Message, MessageChannel
 from user_manager import LocalUserManager
 from user_manager import LocalUserManager
 import time
 import time
@@ -51,6 +52,33 @@ def test_env():
 
 
     return service, queues
     return service, queues
 
 
+def test_agent_state_change(test_env):
+    service, _ = test_env
+    agent = service._get_agent_instance('staff_id_0', 'user_id_0')
+    assert agent.current_state == DialogueState.INITIALIZED
+    assert agent.previous_state == DialogueState.INITIALIZED
+
+    agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
+    assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
+    assert agent.previous_state == DialogueState.INITIALIZED
+
+    agent.do_state_change(DialogueState.GREETING)
+    assert agent.current_state == DialogueState.GREETING
+    assert agent.previous_state == DialogueState.INITIALIZED
+
+    agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
+    assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
+    assert agent.previous_state == DialogueState.GREETING
+
+    agent.do_state_change(DialogueState.CHITCHAT)
+    assert agent.current_state == DialogueState.CHITCHAT
+    assert agent.previous_state == DialogueState.GREETING
+
+    agent.rollback_state()
+    assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
+    assert agent.previous_state == DialogueState.GREETING
+
+
 def test_response_sanitization(test_env):
 def test_response_sanitization(test_env):
     case1 = '[2024-01-01 12:00:00] 你好'
     case1 = '[2024-01-01 12:00:00] 你好'
     ret1 = AgentService.sanitize_response(case1)
     ret1 = AgentService.sanitize_response(case1)
@@ -159,7 +187,9 @@ def test_initiative_conversation(test_env):
 
 
     # 设置Agent需要主动发起对话
     # 设置Agent需要主动发起对话
     agent = service._get_agent_instance('staff_id_0', "user_id_0")
     agent = service._get_agent_instance('staff_id_0', "user_id_0")
-    agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
+    # agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
+    # 发起对话有时间限制
+    agent.get_time_context = Mock(return_value=TimeContext.MORNING)
 
 
     service._check_initiative_conversations()
     service._check_initiative_conversations()