unit_test.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import pytest
  5. from datetime import datetime, timedelta
  6. from typing import Dict, Optional, Tuple, Any
  7. from unittest.mock import Mock, MagicMock
  8. from agent_service import AgentService, MemoryQueueBackend
  9. from message import MessageType
  10. from user_manager import LocalUserManager
  11. import time
  12. import logging
  13. class TestMessageQueues:
  14. """测试用消息队列实现"""
  15. def __init__(self, receive_queue, send_queue, human_queue):
  16. self.receive_queue = receive_queue
  17. self.send_queue = send_queue
  18. self.human_queue = human_queue
  19. @pytest.fixture
  20. def test_env():
  21. """测试环境初始化"""
  22. logging.getLogger().setLevel(logging.DEBUG)
  23. user_manager = LocalUserManager()
  24. receive_queue = MemoryQueueBackend()
  25. send_queue = MemoryQueueBackend()
  26. human_queue = MemoryQueueBackend()
  27. queues = TestMessageQueues(receive_queue, send_queue, human_queue)
  28. # 创建Agent服务实例
  29. service = AgentService(
  30. receive_backend=receive_queue,
  31. send_backend=send_queue,
  32. human_backend=human_queue,
  33. user_manager=user_manager
  34. )
  35. service.user_profile_extractor.extract_profile_info = Mock(return_value=None)
  36. # 替换LLM调用为模拟响应
  37. service._call_chat_api = Mock(return_value="模拟响应")
  38. return service, queues
  39. def test_normal_conversation_flow(test_env):
  40. """测试正常对话流程"""
  41. service, queues = test_env
  42. service._get_agent_instance("user_id_0").message_aggregation_sec = 0
  43. # 准备测试消息
  44. test_msg = {
  45. "user_id": "user_id_0",
  46. "type": MessageType.TEXT,
  47. "text": "你好",
  48. "timestamp": int(time.time() * 1000),
  49. }
  50. queues.receive_queue.produce(test_msg)
  51. # 处理消息
  52. message = service.receive_queue.consume()
  53. if message:
  54. service.process_single_message(message)
  55. # 验证响应消息
  56. sent_msg = queues.send_queue.consume()
  57. assert sent_msg is not None
  58. assert sent_msg["user_id"] == "user_id_0"
  59. assert "模拟响应" in sent_msg["text"]
  60. def test_aggregated_conversation_flow(test_env):
  61. """测试聚合对话流程"""
  62. service, queues = test_env
  63. service._get_agent_instance("user_id_0").message_aggregation_sec = 1
  64. # 准备测试消息
  65. ts_begin = int(time.time() * 1000)
  66. test_msg = {
  67. "user_id": "user_id_0",
  68. "type": MessageType.TEXT,
  69. "text": "你好",
  70. "timestamp": ts_begin,
  71. }
  72. queues.receive_queue.produce(test_msg)
  73. test_msg = {
  74. "user_id": "user_id_0",
  75. "type": MessageType.TEXT,
  76. "text": "我是老李",
  77. "timestamp": ts_begin + 0.5 * 1000,
  78. }
  79. queues.receive_queue.produce(test_msg)
  80. # 处理消息
  81. message = service.receive_queue.consume()
  82. if message:
  83. service.process_single_message(message)
  84. # 验证第一次响应消息
  85. sent_msg = queues.send_queue.consume()
  86. assert sent_msg is None
  87. message = service.receive_queue.consume()
  88. if message:
  89. service.process_single_message(message)
  90. # 验证第二次响应消息
  91. sent_msg = queues.send_queue.consume()
  92. assert sent_msg is None
  93. # 模拟定时器产生空消息触发响应
  94. service.process_single_message({
  95. "user_id": "user_id_0",
  96. "type": MessageType.AGGREGATION_TRIGGER,
  97. "timestamp": ts_begin + 2 * 1000
  98. })
  99. # 验证第三次响应消息
  100. sent_msg = queues.send_queue.consume()
  101. assert sent_msg is not None
  102. assert sent_msg["user_id"] == "user_id_0"
  103. assert "模拟响应" in sent_msg["text"]
  104. def test_human_intervention_trigger(test_env):
  105. """测试触发人工干预"""
  106. service, queues = test_env
  107. service._get_agent_instance("user_id_0").message_aggregation_sec = 0
  108. # 准备需要人工干预的消息
  109. test_msg = {
  110. "user_id": "user_id_0",
  111. "type": MessageType.TEXT,
  112. "text": "我需要帮助!",
  113. "timestamp": int(time.time() * 1000),
  114. }
  115. queues.receive_queue.produce(test_msg)
  116. # 处理消息
  117. message = service.receive_queue.consume()
  118. if message:
  119. service.process_single_message(message)
  120. # 验证人工队列消息
  121. human_msg = queues.human_queue.consume()
  122. assert human_msg is not None
  123. assert human_msg["user_id"] == "user_id_0"
  124. assert "state" in human_msg
  125. def test_initiative_conversation(test_env):
  126. """测试主动发起对话"""
  127. service, queues = test_env
  128. service._get_agent_instance("user_id_0").message_aggregation_sec = 0
  129. service._call_chat_api = Mock(return_value="主动发起模拟消息")
  130. # 设置Agent需要主动发起对话
  131. agent = service._get_agent_instance("user_id_0")
  132. agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
  133. service._check_initiative_conversations()
  134. # 验证主动发起的消息
  135. sent_msg = queues.send_queue.consume()
  136. assert sent_msg is not None
  137. assert "主动发起" in sent_msg["text"]