|
@@ -7,7 +7,7 @@ from datetime import datetime, timedelta
|
|
|
from typing import Dict, Optional, Tuple, Any
|
|
|
from unittest.mock import Mock, MagicMock
|
|
|
from agent_service import AgentService, MemoryQueueBackend
|
|
|
-from message import MessageType
|
|
|
+from message import MessageType, Message, MessageChannel
|
|
|
from user_manager import LocalUserManager
|
|
|
import time
|
|
|
import logging
|
|
@@ -49,15 +49,12 @@ def test_env():
|
|
|
def test_normal_conversation_flow(test_env):
|
|
|
"""测试正常对话流程"""
|
|
|
service, queues = test_env
|
|
|
- service._get_agent_instance("user_id_0").message_aggregation_sec = 0
|
|
|
+ service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
|
|
|
|
|
|
# 准备测试消息
|
|
|
- test_msg = {
|
|
|
- "user_id": "user_id_0",
|
|
|
- "type": MessageType.TEXT,
|
|
|
- "text": "你好",
|
|
|
- "timestamp": int(time.time() * 1000),
|
|
|
- }
|
|
|
+ test_msg = Message.build(
|
|
|
+ MessageType.TEXT, MessageChannel.CORP_WECHAT,
|
|
|
+ 'user_id_0', 'staff_id_0', '你好', int(time.time() * 1000))
|
|
|
queues.receive_queue.produce(test_msg)
|
|
|
|
|
|
# 处理消息
|
|
@@ -68,29 +65,23 @@ def test_normal_conversation_flow(test_env):
|
|
|
# 验证响应消息
|
|
|
sent_msg = queues.send_queue.consume()
|
|
|
assert sent_msg is not None
|
|
|
- assert sent_msg["user_id"] == "user_id_0"
|
|
|
- assert "模拟响应" in sent_msg["text"]
|
|
|
+ assert sent_msg.receiver == "user_id_0"
|
|
|
+ assert "模拟响应" in sent_msg.content
|
|
|
|
|
|
def test_aggregated_conversation_flow(test_env):
|
|
|
"""测试聚合对话流程"""
|
|
|
service, queues = test_env
|
|
|
- service._get_agent_instance("user_id_0").message_aggregation_sec = 1
|
|
|
+ service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 1
|
|
|
|
|
|
# 准备测试消息
|
|
|
ts_begin = int(time.time() * 1000)
|
|
|
- test_msg = {
|
|
|
- "user_id": "user_id_0",
|
|
|
- "type": MessageType.TEXT,
|
|
|
- "text": "你好",
|
|
|
- "timestamp": ts_begin,
|
|
|
- }
|
|
|
+ test_msg = Message.build(
|
|
|
+ MessageType.TEXT, MessageChannel.CORP_WECHAT,
|
|
|
+ 'user_id_0', 'staff_id_0', '你好', ts_begin)
|
|
|
queues.receive_queue.produce(test_msg)
|
|
|
- test_msg = {
|
|
|
- "user_id": "user_id_0",
|
|
|
- "type": MessageType.TEXT,
|
|
|
- "text": "我是老李",
|
|
|
- "timestamp": ts_begin + 0.5 * 1000,
|
|
|
- }
|
|
|
+ test_msg = Message.build(
|
|
|
+ MessageType.TEXT, MessageChannel.CORP_WECHAT,
|
|
|
+ 'user_id_0', 'staff_id_0', '我是老李', ts_begin + 500)
|
|
|
queues.receive_queue.produce(test_msg)
|
|
|
|
|
|
# 处理消息
|
|
@@ -110,29 +101,27 @@ def test_aggregated_conversation_flow(test_env):
|
|
|
assert sent_msg is None
|
|
|
|
|
|
# 模拟定时器产生空消息触发响应
|
|
|
- service.process_single_message({
|
|
|
- "user_id": "user_id_0",
|
|
|
- "type": MessageType.AGGREGATION_TRIGGER,
|
|
|
- "timestamp": ts_begin + 2 * 1000
|
|
|
- })
|
|
|
+ service.process_single_message(Message.build(
|
|
|
+ MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
|
|
|
+ 'user_id_0', 'staff_id_0', None, ts_begin + 2000
|
|
|
+ ))
|
|
|
# 验证第三次响应消息
|
|
|
sent_msg = queues.send_queue.consume()
|
|
|
assert sent_msg is not None
|
|
|
- assert sent_msg["user_id"] == "user_id_0"
|
|
|
- assert "模拟响应" in sent_msg["text"]
|
|
|
+ assert sent_msg.receiver == "user_id_0"
|
|
|
+ assert "模拟响应" in sent_msg.content
|
|
|
|
|
|
def test_human_intervention_trigger(test_env):
|
|
|
"""测试触发人工干预"""
|
|
|
service, queues = test_env
|
|
|
- service._get_agent_instance("user_id_0").message_aggregation_sec = 0
|
|
|
+ service._get_agent_instance('staff_id_0',"user_id_0").message_aggregation_sec = 0
|
|
|
|
|
|
# 准备需要人工干预的消息
|
|
|
- test_msg = {
|
|
|
- "user_id": "user_id_0",
|
|
|
- "type": MessageType.TEXT,
|
|
|
- "text": "我需要帮助!",
|
|
|
- "timestamp": int(time.time() * 1000),
|
|
|
- }
|
|
|
+ test_msg = Message.build(
|
|
|
+ MessageType.TEXT, MessageChannel.CORP_WECHAT,
|
|
|
+ "user_id_0", "staff_id_0",
|
|
|
+ "我需要帮助!", int(time.time() * 1000)
|
|
|
+ )
|
|
|
queues.receive_queue.produce(test_msg)
|
|
|
|
|
|
# 处理消息
|
|
@@ -143,17 +132,17 @@ def test_human_intervention_trigger(test_env):
|
|
|
# 验证人工队列消息
|
|
|
human_msg = queues.human_queue.consume()
|
|
|
assert human_msg is not None
|
|
|
- assert human_msg["user_id"] == "user_id_0"
|
|
|
- assert "state" in human_msg
|
|
|
+ assert human_msg.sender == "user_id_0"
|
|
|
+ assert "用户对话需人工介入" in human_msg.content
|
|
|
|
|
|
def test_initiative_conversation(test_env):
|
|
|
"""测试主动发起对话"""
|
|
|
service, queues = test_env
|
|
|
- service._get_agent_instance("user_id_0").message_aggregation_sec = 0
|
|
|
+ service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
|
|
|
service._call_chat_api = Mock(return_value="主动发起模拟消息")
|
|
|
|
|
|
# 设置Agent需要主动发起对话
|
|
|
- agent = service._get_agent_instance("user_id_0")
|
|
|
+ agent = service._get_agent_instance('staff_id_0', "user_id_0")
|
|
|
agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
|
|
|
|
|
|
service._check_initiative_conversations()
|
|
@@ -161,4 +150,4 @@ def test_initiative_conversation(test_env):
|
|
|
# 验证主动发起的消息
|
|
|
sent_msg = queues.send_queue.consume()
|
|
|
assert sent_msg is not None
|
|
|
- assert "主动发起" in sent_msg["text"]
|
|
|
+ assert "主动发起" in sent_msg.content
|