|
@@ -7,6 +7,7 @@ from datetime import datetime, timedelta
|
|
|
from typing import Dict, Optional, Tuple, Any
|
|
|
from unittest.mock import Mock, MagicMock
|
|
|
from agent_service import AgentService, MemoryQueueBackend
|
|
|
+from dialogue_manager import DialogueState, TimeContext
|
|
|
from message import MessageType, Message, MessageChannel
|
|
|
from user_manager import LocalUserManager
|
|
|
import time
|
|
@@ -51,6 +52,33 @@ def test_env():
|
|
|
|
|
|
return service, queues
|
|
|
|
|
|
+def test_agent_state_change(test_env):
|
|
|
+ service, _ = test_env
|
|
|
+ agent = service._get_agent_instance('staff_id_0', 'user_id_0')
|
|
|
+ assert agent.current_state == DialogueState.INITIALIZED
|
|
|
+ assert agent.previous_state == DialogueState.INITIALIZED
|
|
|
+
|
|
|
+ agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
|
|
|
+ assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
|
|
|
+ assert agent.previous_state == DialogueState.INITIALIZED
|
|
|
+
|
|
|
+ agent.do_state_change(DialogueState.GREETING)
|
|
|
+ assert agent.current_state == DialogueState.GREETING
|
|
|
+ assert agent.previous_state == DialogueState.INITIALIZED
|
|
|
+
|
|
|
+ agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
|
|
|
+ assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
|
|
|
+ assert agent.previous_state == DialogueState.GREETING
|
|
|
+
|
|
|
+ agent.do_state_change(DialogueState.CHITCHAT)
|
|
|
+ assert agent.current_state == DialogueState.CHITCHAT
|
|
|
+ assert agent.previous_state == DialogueState.GREETING
|
|
|
+
|
|
|
+ agent.rollback_state()
|
|
|
+ assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
|
|
|
+ assert agent.previous_state == DialogueState.GREETING
|
|
|
+
|
|
|
+
|
|
|
def test_response_sanitization(test_env):
|
|
|
case1 = '[2024-01-01 12:00:00] 你好'
|
|
|
ret1 = AgentService.sanitize_response(case1)
|
|
@@ -159,7 +187,9 @@ def test_initiative_conversation(test_env):
|
|
|
|
|
|
# 设置Agent需要主动发起对话
|
|
|
agent = service._get_agent_instance('staff_id_0', "user_id_0")
|
|
|
- agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
|
|
|
+ # agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
|
|
|
+ # 发起对话有时间限制
|
|
|
+ agent.get_time_context = Mock(return_value=TimeContext.MORNING)
|
|
|
|
|
|
service._check_initiative_conversations()
|
|
|
|