unit_test.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import pytest
  5. from datetime import datetime, timedelta
  6. from typing import Dict, Optional, Tuple, Any
  7. from unittest.mock import Mock, MagicMock
  8. from agent_service import AgentService, MemoryQueueBackend
  9. from dialogue_manager import DialogueState, TimeContext
  10. from message import MessageType, Message, MessageChannel
  11. from user_manager import LocalUserManager
  12. import time
  13. import logging
  14. class TestMessageQueues:
  15. """测试用消息队列实现"""
  16. def __init__(self, receive_queue, send_queue, human_queue):
  17. self.receive_queue = receive_queue
  18. self.send_queue = send_queue
  19. self.human_queue = human_queue
  20. @pytest.fixture
  21. def test_env():
  22. """测试环境初始化"""
  23. logging.getLogger().setLevel(logging.DEBUG)
  24. user_manager = LocalUserManager()
  25. user_relation_manager = Mock()
  26. user_relation_manager.get_user_tags = Mock(return_value=['AgentTest1'])
  27. user_relation_manager.list_staff_users = Mock(return_value=[{'staff_id': 'staff_id_0', 'user_id': 'user_id_0'}])
  28. receive_queue = MemoryQueueBackend()
  29. send_queue = MemoryQueueBackend()
  30. human_queue = MemoryQueueBackend()
  31. queues = TestMessageQueues(receive_queue, send_queue, human_queue)
  32. # 创建Agent服务实例
  33. service = AgentService(
  34. receive_backend=receive_queue,
  35. send_backend=send_queue,
  36. human_backend=human_queue,
  37. user_manager=user_manager,
  38. user_relation_manager=user_relation_manager
  39. )
  40. service.user_profile_extractor.extract_profile_info = Mock(return_value=None)
  41. service.limit_initiative_conversation_rate = False
  42. # 替换LLM调用为模拟响应
  43. service._call_chat_api = Mock(return_value="模拟响应")
  44. return service, queues
  45. def test_agent_state_change(test_env):
  46. service, _ = test_env
  47. agent = service._get_agent_instance('staff_id_0', 'user_id_0')
  48. assert agent.current_state == DialogueState.INITIALIZED
  49. assert agent.previous_state == DialogueState.INITIALIZED
  50. agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
  51. assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
  52. assert agent.previous_state == DialogueState.INITIALIZED
  53. agent.do_state_change(DialogueState.GREETING)
  54. assert agent.current_state == DialogueState.GREETING
  55. assert agent.previous_state == DialogueState.INITIALIZED
  56. agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
  57. assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
  58. assert agent.previous_state == DialogueState.GREETING
  59. agent.do_state_change(DialogueState.CHITCHAT)
  60. assert agent.current_state == DialogueState.CHITCHAT
  61. assert agent.previous_state == DialogueState.GREETING
  62. agent.rollback_state()
  63. assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
  64. assert agent.previous_state == DialogueState.GREETING
  65. agent.do_state_change(DialogueState.CHITCHAT)
  66. assert agent.current_state == DialogueState.CHITCHAT
  67. assert agent.previous_state == DialogueState.GREETING
  68. agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
  69. agent.do_state_change(DialogueState.CHITCHAT)
  70. assert agent.current_state == DialogueState.CHITCHAT
  71. assert agent.previous_state == DialogueState.CHITCHAT
  72. agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
  73. agent.do_state_change(DialogueState.CHITCHAT)
  74. assert agent.state_backup == (DialogueState.MESSAGE_AGGREGATING, DialogueState.CHITCHAT)
  75. agent.rollback_state()
  76. assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
  77. def test_response_sanitization(test_env):
  78. case1 = '[2024-01-01 12:00:00] 你好'
  79. ret1 = AgentService.sanitize_response(case1)
  80. assert ret1 == '你好'
  81. case1 = '2024-01-01 12:00:00 你好'
  82. ret2 = AgentService.sanitize_response(case1)
  83. assert ret2 == '你好'
  84. def test_normal_conversation_flow(test_env):
  85. """测试正常对话流程"""
  86. service, queues = test_env
  87. service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
  88. # 准备测试消息
  89. test_msg = Message.build(
  90. MessageType.TEXT, MessageChannel.CORP_WECHAT,
  91. 'user_id_0', 'staff_id_0', '你好', int(time.time() * 1000))
  92. queues.receive_queue.produce(test_msg)
  93. # 处理消息
  94. message = service.receive_queue.consume()
  95. if message:
  96. service.process_single_message(message)
  97. # 验证响应消息
  98. sent_msg = queues.send_queue.consume()
  99. assert sent_msg is not None
  100. assert sent_msg.receiver == "user_id_0"
  101. assert "模拟响应" in sent_msg.content
  102. def test_aggregated_conversation_flow(test_env):
  103. """测试聚合对话流程"""
  104. service, queues = test_env
  105. service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 1
  106. # 准备测试消息
  107. ts_begin = int(time.time() * 1000)
  108. test_msg = Message.build(
  109. MessageType.TEXT, MessageChannel.CORP_WECHAT,
  110. 'user_id_0', 'staff_id_0', '你好', ts_begin)
  111. queues.receive_queue.produce(test_msg)
  112. test_msg = Message.build(
  113. MessageType.TEXT, MessageChannel.CORP_WECHAT,
  114. 'user_id_0', 'staff_id_0', '我是老李', ts_begin + 500)
  115. queues.receive_queue.produce(test_msg)
  116. # 处理消息
  117. message = service.receive_queue.consume()
  118. if message:
  119. service.process_single_message(message)
  120. # 验证第一次响应消息
  121. sent_msg = queues.send_queue.consume()
  122. assert sent_msg is None
  123. message = service.receive_queue.consume()
  124. if message:
  125. service.process_single_message(message)
  126. # 验证第二次响应消息
  127. sent_msg = queues.send_queue.consume()
  128. assert sent_msg is None
  129. # 模拟定时器产生空消息触发响应
  130. service.process_single_message(Message.build(
  131. MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
  132. 'user_id_0', 'staff_id_0', None, ts_begin + 2000
  133. ))
  134. # 验证第三次响应消息
  135. sent_msg = queues.send_queue.consume()
  136. assert sent_msg is not None
  137. assert sent_msg.receiver == "user_id_0"
  138. assert "模拟响应" in sent_msg.content
  139. def test_human_intervention_trigger(test_env):
  140. """测试触发人工干预"""
  141. service, queues = test_env
  142. service._get_agent_instance('staff_id_0',"user_id_0").message_aggregation_sec = 0
  143. # 准备需要人工干预的消息
  144. test_msg = Message.build(
  145. MessageType.TEXT, MessageChannel.CORP_WECHAT,
  146. "user_id_0", "staff_id_0",
  147. "我需要帮助!", int(time.time() * 1000)
  148. )
  149. queues.receive_queue.produce(test_msg)
  150. # 处理消息
  151. message = service.receive_queue.consume()
  152. if message:
  153. service.process_single_message(message)
  154. # 验证人工队列消息
  155. human_msg = queues.human_queue.consume()
  156. # 由于相关逻辑未启用,临时关闭该测试
  157. return
  158. assert human_msg is not None
  159. assert human_msg.sender == "user_id_0"
  160. assert "用户对话需人工介入" in human_msg.content
  161. def test_initiative_conversation(test_env):
  162. """测试主动发起对话"""
  163. service, queues = test_env
  164. service._get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
  165. service._call_chat_api = Mock(return_value="主动发起模拟消息")
  166. # 设置Agent需要主动发起对话
  167. agent = service._get_agent_instance('staff_id_0', "user_id_0")
  168. # agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
  169. # 发起对话有时间限制
  170. agent.get_time_context = Mock(return_value=TimeContext.MORNING)
  171. service._check_initiative_conversations()
  172. # 验证主动发起的消息
  173. sent_msg = queues.send_queue.consume()
  174. assert sent_msg is not None
  175. assert "主动发起" in sent_msg.content