123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214 |
- #! /usr/bin/env python
- # -*- coding: utf-8 -*-
- # vim:fenc=utf-8
- import pytest
- from datetime import datetime, timedelta
- from typing import Dict, Optional, Tuple, Any
- from unittest.mock import Mock, MagicMock
- from agent_service import AgentService, MemoryQueueBackend
- from dialogue_manager import DialogueState, TimeContext
- from message import MessageType, Message, MessageChannel
- from user_manager import LocalUserManager
- import time
- import logging
class TestMessageQueues:
    """Plain holder for the receive, send and human queue backends.

    Stores the three backends as attributes so a test can produce to and
    consume from them directly.
    """

    # The ``Test`` name prefix would otherwise make pytest try to collect this
    # helper as a test class (and warn, because it defines ``__init__``).
    __test__ = False

    def __init__(self, receive_queue, send_queue, human_queue):
        self.receive_queue = receive_queue
        self.send_queue = send_queue
        self.human_queue = human_queue
@pytest.fixture
def test_env():
    """Build an ``AgentService`` wired to in-memory queues and mocked helpers.

    Yields:
        A ``(service, queues)`` tuple. ``service`` is the ``AgentService``
        under test. ``queues`` is a ``TestMessageQueues`` holding the
        receive/send/human ``MemoryQueueBackend`` instances passed to it.

    The root logger is set to DEBUG while the test runs. Its previous level is
    restored on teardown so the change does not leak into other tests.
    """
    root_logger = logging.getLogger()
    previous_level = root_logger.level
    root_logger.setLevel(logging.DEBUG)
    try:
        user_manager = LocalUserManager()
        # Mocked relation manager: get_user_tags always returns ['AgentTest1']
        # and the only staff/user pair listed is staff_id_0 / user_id_0.
        user_relation_manager = Mock()
        user_relation_manager.get_user_tags = Mock(return_value=['AgentTest1'])
        user_relation_manager.list_staff_users = Mock(
            return_value=[{'staff_id': 'staff_id_0', 'user_id': 'user_id_0'}])

        receive_queue = MemoryQueueBackend()
        send_queue = MemoryQueueBackend()
        human_queue = MemoryQueueBackend()
        queues = TestMessageQueues(receive_queue, send_queue, human_queue)

        # Create the AgentService instance under test
        service = AgentService(
            receive_backend=receive_queue,
            send_backend=send_queue,
            human_backend=human_queue,
            user_manager=user_manager,
            user_relation_manager=user_relation_manager
        )
        service.user_profile_extractor.extract_profile_info = Mock(return_value=None)
        # NOTE(review): presumably disables initiative-conversation rate
        # limiting so tests are not throttled; confirm in AgentService.
        service.limit_initiative_conversation_rate = False
        # Replace the LLM call with a canned response
        service._call_chat_api = Mock(return_value="模拟响应")

        yield service, queues
    finally:
        root_logger.setLevel(previous_level)
def test_agent_state_change(test_env):
    """Drive one agent through state changes and rollbacks, checking each step.

    After every transition the test pins both ``current_state`` and
    ``previous_state``. Before the final rollback it also pins
    ``state_backup``.
    """
    service, _ = test_env
    agent = service._get_agent_instance('staff_id_0', 'user_id_0')

    def expect(current, previous):
        # Compare both tracked states against the expected pair
        assert agent.current_state == current
        assert agent.previous_state == previous

    # A freshly created agent starts in INITIALIZED for both slots
    expect(DialogueState.INITIALIZED, DialogueState.INITIALIZED)

    # The expectations below encode that entering MESSAGE_AGGREGATING records
    # the state being left as previous_state, and that leaving it keeps
    # previous_state as-is.
    agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
    expect(DialogueState.MESSAGE_AGGREGATING, DialogueState.INITIALIZED)

    agent.do_state_change(DialogueState.GREETING)
    expect(DialogueState.GREETING, DialogueState.INITIALIZED)

    agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
    expect(DialogueState.MESSAGE_AGGREGATING, DialogueState.GREETING)

    agent.do_state_change(DialogueState.CHITCHAT)
    expect(DialogueState.CHITCHAT, DialogueState.GREETING)

    # Rolling back undoes only the last transition
    agent.rollback_state()
    expect(DialogueState.MESSAGE_AGGREGATING, DialogueState.GREETING)

    agent.do_state_change(DialogueState.CHITCHAT)
    expect(DialogueState.CHITCHAT, DialogueState.GREETING)

    # Going CHITCHAT -> MESSAGE_AGGREGATING -> CHITCHAT makes CHITCHAT previous
    agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
    agent.do_state_change(DialogueState.CHITCHAT)
    expect(DialogueState.CHITCHAT, DialogueState.CHITCHAT)

    agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
    agent.do_state_change(DialogueState.CHITCHAT)
    assert agent.state_backup == (
        DialogueState.MESSAGE_AGGREGATING, DialogueState.CHITCHAT)

    agent.rollback_state()
    assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
def test_response_sanitization(test_env):
    """``AgentService.sanitize_response`` strips a leading timestamp.

    Both the bracketed ``[YYYY-mm-dd HH:MM:SS]`` prefix and the bare
    ``YYYY-mm-dd HH:MM:SS`` prefix are expected to be removed, leaving only
    the reply text.

    ``test_env`` is not used by the assertions. It is kept in the signature so
    the test runs under the same fixture setup as before.
    """
    # Each entry is (raw model output, expected sanitized text). Distinct
    # entries replace the previous reuse of ``case1`` for both inputs.
    cases = [
        ('[2024-01-01 12:00:00] 你好', '你好'),
        ('2024-01-01 12:00:00 你好', '你好'),
    ]
    for raw, expected in cases:
        assert AgentService.sanitize_response(raw) == expected, raw
def test_normal_conversation_flow(test_env):
    """A single text message with a 0 s aggregation window is answered at once."""
    service, queues = test_env
    service._get_agent_instance('staff_id_0', 'user_id_0').message_aggregation_sec = 0

    # Build one text message (millisecond timestamp) and enqueue it
    test_msg = Message.build(
        MessageType.TEXT, MessageChannel.CORP_WECHAT,
        'user_id_0', 'staff_id_0', '你好', int(time.time() * 1000))
    queues.receive_queue.produce(test_msg)

    # Consume through the service and process it. Assert rather than silently
    # skipping: a missing message would otherwise surface later as a
    # misleading "no reply was sent" failure.
    message = service.receive_queue.consume()
    assert message is not None, 'produced message was not consumed from receive_queue'
    service.process_single_message(message)

    # The reply is addressed to the user and carries the mocked LLM text
    sent_msg = queues.send_queue.consume()
    assert sent_msg is not None
    assert sent_msg.receiver == "user_id_0"
    assert "模拟响应" in sent_msg.content
def test_aggregated_conversation_flow(test_env):
    """Messages inside the aggregation window are answered only once triggered.

    With ``message_aggregation_sec = 1``, two messages 500 ms apart must not
    produce a reply on their own. A single reply is expected only after an
    ``AGGREGATION_TRIGGER`` message, standing in for the timer, is processed.
    """
    service, queues = test_env
    service._get_agent_instance('staff_id_0', 'user_id_0').message_aggregation_sec = 1

    # Two text messages 500 ms apart (millisecond timestamps)
    ts_begin = int(time.time() * 1000)
    first_msg = Message.build(
        MessageType.TEXT, MessageChannel.CORP_WECHAT,
        'user_id_0', 'staff_id_0', '你好', ts_begin)
    queues.receive_queue.produce(first_msg)
    second_msg = Message.build(
        MessageType.TEXT, MessageChannel.CORP_WECHAT,
        'user_id_0', 'staff_id_0', '我是老李', ts_begin + 500)
    queues.receive_queue.produce(second_msg)

    for label in ('first', 'second'):
        # Both produced messages must actually be consumed and processed.
        # A silent skip here would let the test pass without exercising
        # aggregation at all.
        message = service.receive_queue.consume()
        assert message is not None, f'{label} message was not consumed from receive_queue'
        service.process_single_message(message)
        # No reply is expected while messages are still being aggregated
        assert queues.send_queue.consume() is None, f'reply sent early after {label} message'

    # Simulate the aggregation timer firing: a content-less trigger message
    # timestamped after the window has elapsed
    service.process_single_message(Message.build(
        MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
        'user_id_0', 'staff_id_0', None, ts_begin + 2000
    ))

    # Now exactly the aggregated reply is sent
    sent_msg = queues.send_queue.consume()
    assert sent_msg is not None
    assert sent_msg.receiver == "user_id_0"
    assert "模拟响应" in sent_msg.content
# The human-intervention path is not enabled yet. Mark the expected failure
# explicitly instead of returning early, which reported a disabled test as
# passing. Only AssertionError counts as the expected failure, so real errors
# in message processing still fail the test. A non-strict xfail turns into
# XPASS once the feature is enabled.
@pytest.mark.xfail(
    reason='human-intervention routing is not enabled yet',
    raises=AssertionError,
    strict=False,
)
def test_human_intervention_trigger(test_env):
    """A message asking for help should be forwarded to the human queue."""
    service, queues = test_env
    service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0

    # Prepare a message that should require human intervention
    test_msg = Message.build(
        MessageType.TEXT, MessageChannel.CORP_WECHAT,
        "user_id_0", "staff_id_0",
        "我需要帮助!", int(time.time() * 1000)
    )
    queues.receive_queue.produce(test_msg)

    message = service.receive_queue.consume()
    if message:
        service.process_single_message(message)

    # Check what was routed to the human-intervention queue
    human_msg = queues.human_queue.consume()
    assert human_msg is not None
    assert human_msg.sender == "user_id_0"
    assert "用户对话需人工介入" in human_msg.content
def test_initiative_conversation(test_env):
    """Running the initiative check makes the service message the user first.

    The chat API is stubbed with a fixed proactive text. The agent's time
    context is pinned to ``TimeContext.MORNING`` so the check is not blocked
    by the time restriction on starting conversations.
    """
    service, queues = test_env
    agent = service._get_agent_instance('staff_id_0', "user_id_0")
    agent.message_aggregation_sec = 0
    service._call_chat_api = Mock(return_value="主动发起模拟消息")
    # Starting a conversation is time-restricted, so fix the time context
    agent.get_time_context = Mock(return_value=TimeContext.MORNING)

    service._check_initiative_conversations()

    # The proactive message must have been sent
    outgoing = queues.send_queue.consume()
    assert outgoing is not None
    assert "主动发起" in outgoing.content
|