unit_test.py 8.6 KB


  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import pytest
  5. from unittest.mock import Mock, MagicMock
  6. import pqai_agent.abtest.client
  7. from pqai_agent.agent_service import AgentService
  8. from pqai_agent.dialogue_manager import DialogueState, TimeContext
  9. from pqai_agent.message_queue_backend import MemoryQueueBackend
  10. from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
  11. from pqai_agent.response_type_detector import ResponseTypeDetector
  12. from pqai_agent.user_manager import LocalUserManager
  13. import time
  14. import logging
  class TestMessageQueues:
      """Holder for the in-memory queue backends wired into one AgentService under test.

      The ``test_env`` fixture yields this object next to the service so that tests
      can produce incoming messages into ``receive_queue`` and inspect what the
      service wrote to ``send_queue`` and ``human_queue``.
      """

      # The class name starts with "Test", so without this flag pytest tries to
      # collect it as a test class and warns because it defines __init__.
      __test__ = False

      def __init__(self, receive_queue, send_queue, human_queue):
          """Store the three queue backends.

          Args:
              receive_queue: backend passed to AgentService as ``receive_backend``.
              send_queue: backend passed to AgentService as ``send_backend``.
              human_queue: backend passed to AgentService as ``human_backend``.
          """
          self.receive_queue = receive_queue
          self.send_queue = send_queue
          self.human_queue = human_queue
  @pytest.fixture
  def test_env():
      """Build and start an AgentService wired to in-memory queues.

      Collaborators are replaced with mocks: the user-relation manager, profile
      extraction, the send-permission check and the chat API.

      Yields:
          tuple: ``(service, queues)``. ``service`` is the started AgentService and
          ``queues`` is a TestMessageQueues holding the backends it was built with.

      Teardown shuts down the service and the A/B-test client, then restores the
      root logger level. Each cleanup step runs even if an earlier one raises.
      """
      root_logger = logging.getLogger()
      previous_level = root_logger.level
      root_logger.setLevel(logging.DEBUG)

      user_manager = LocalUserManager()
      user_relation_manager = Mock()
      user_relation_manager.get_user_tags = Mock(return_value=['AgentTest1'])
      user_relation_manager.list_staff_users = Mock(
          return_value=[{'staff_id': 'staff_id_0', 'user_id': 'user_id_0'}])

      receive_queue = MemoryQueueBackend()
      send_queue = MemoryQueueBackend()
      human_queue = MemoryQueueBackend()
      queues = TestMessageQueues(receive_queue, send_queue, human_queue)

      # Create the agent service instance on top of the in-memory queues.
      service = AgentService(
          receive_backend=receive_queue,
          send_backend=send_queue,
          human_backend=human_queue,
          user_manager=user_manager,
          user_relation_manager=user_relation_manager,
      )
      service.user_profile_extractor.extract_profile_info = Mock(return_value=None)
      service.can_send_to_user = Mock(return_value=True)
      # Stub the LLM call *before* start() so that nothing triggered by start()
      # can ever reach the real chat API.
      service._call_chat_api = Mock(return_value="模拟响应")
      service.start()

      yield service, queues

      # Chain the cleanups so a failure in one does not skip the others.
      try:
          service.shutdown(sync=True)
      finally:
          try:
              pqai_agent.abtest.client.get_client().shutdown(blocking=True)
          finally:
              root_logger.setLevel(previous_level)
  def test_agent_state_change(test_env):
      """Walk one agent through DialogueState transitions, commit() and rollback_state().

      Per the assertions below, moving into MESSAGE_AGGREGATING does not overwrite
      previous_state. A rollback after a commit restores the committed state, and
      a second rollback changes nothing.
      """
      service, _queues = test_env
      agent = service.get_agent_instance('staff_id_0', 'user_id_0')
      S = DialogueState

      def expect(current, previous=None):
          # Check current_state, plus previous_state when one is given.
          assert agent.current_state == current
          if previous is not None:
              assert agent.previous_state == previous

      # A fresh agent starts INITIALIZED on both sides.
      expect(S.INITIALIZED, S.INITIALIZED)

      agent.do_state_change(S.MESSAGE_AGGREGATING)
      expect(S.MESSAGE_AGGREGATING, S.INITIALIZED)

      agent.do_state_change(S.GREETING)
      expect(S.GREETING, S.INITIALIZED)

      agent.do_state_change(S.MESSAGE_AGGREGATING)
      expect(S.MESSAGE_AGGREGATING, S.GREETING)

      # Commit, move on, then roll back to the committed state.
      agent.commit()
      agent.do_state_change(S.CHITCHAT)
      expect(S.CHITCHAT, S.GREETING)
      agent.rollback_state()
      expect(S.MESSAGE_AGGREGATING, S.GREETING)

      agent.do_state_change(S.CHITCHAT)
      expect(S.CHITCHAT, S.GREETING)

      # Passing through MESSAGE_AGGREGATING leaves CHITCHAT as previous_state.
      agent.do_state_change(S.MESSAGE_AGGREGATING)
      agent.do_state_change(S.CHITCHAT)
      expect(S.CHITCHAT, S.CHITCHAT)

      agent.do_state_change(S.MESSAGE_AGGREGATING)
      agent.commit()
      agent.do_state_change(S.CHITCHAT)
      agent.rollback_state()
      expect(S.MESSAGE_AGGREGATING)

      # Nothing is left to roll back, so the state must not change.
      agent.rollback_state()
      expect(S.MESSAGE_AGGREGATING)
  def test_response_sanitization(test_env):
      """sanitize_response must strip a leading timestamp, with or without brackets.

      ``test_env`` is not referenced here; requesting it still runs the fixture's
      setup and teardown.
      """
      # Map raw model output to the expected sanitized text.
      cases = {
          '[2024-01-01 12:00:00] 你好': '你好',
          '2024-01-01 12:00:00 你好': '你好',
      }
      for raw, expected in cases.items():
          sanitized = AgentService.sanitize_response(raw)
          assert sanitized == expected, f"sanitize_response({raw!r}) returned {sanitized!r}"
  def test_normal_conversation_flow(test_env):
      """With aggregation disabled, a single text message gets an immediate reply."""
      service, queues = test_env
      # A zero aggregation window means the reply is expected right after processing.
      service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0

      # Prepare the incoming user message.
      test_msg = MqMessage.build(
          MessageType.TEXT, MessageChannel.CORP_WECHAT,
          'user_id_0', 'staff_id_0', '你好', int(time.time() * 1000))
      queues.receive_queue.produce(test_msg)

      # Process it. Fail here if nothing was consumed, instead of silently
      # skipping processing and failing later with a misleading error.
      message = service.receive_queue.consume()
      assert message is not None, "produced message was not consumed from the receive queue"
      service.process_single_message(message)

      # Verify the reply.
      sent_msg = queues.send_queue.consume()
      assert sent_msg is not None
      assert sent_msg.receiver == "user_id_0"
      assert "模拟响应" in sent_msg.content
  def test_aggregated_conversation_flow(test_env):
      """Messages within the aggregation window get no reply until a trigger fires."""
      service, queues = test_env
      service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 1

      def build_text(content, ts):
          # Build a user -> staff text message with the given millisecond timestamp.
          return MqMessage.build(
              MessageType.TEXT, MessageChannel.CORP_WECHAT,
              'user_id_0', 'staff_id_0', content, ts)

      def process_next():
          # Hand the next queued incoming message to the service. Fail loudly
          # if the queue is empty rather than silently skipping processing.
          message = service.receive_queue.consume()
          assert message is not None, "expected a queued incoming message"
          service.process_single_message(message)

      # Two messages 500 ms apart, inside the 1 s aggregation window.
      ts_begin = int(time.time() * 1000)
      queues.receive_queue.produce(build_text('你好', ts_begin))
      queues.receive_queue.produce(build_text('我是老李', ts_begin + 500))

      # Neither message alone should produce a reply.
      process_next()
      assert queues.send_queue.consume() is None
      process_next()
      assert queues.send_queue.consume() is None

      # Simulate the timer's empty trigger message, 2 s after the first message.
      service.process_single_message(MqMessage.build(
          MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
          'user_id_0', 'staff_id_0', None, ts_begin + 2000
      ))

      # The aggregated reply is now sent.
      sent_msg = queues.send_queue.consume()
      assert sent_msg is not None
      assert sent_msg.receiver == "user_id_0"
      assert "模拟响应" in sent_msg.content
  def test_human_intervention_trigger(test_env):
      """A help request should be escalated to the human queue (currently skipped).

      The message is still processed, so an exception during processing fails the
      test. The escalation assertions are skipped until that logic is enabled.
      """
      service, queues = test_env
      service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0

      # Prepare a message that is expected to need human intervention.
      test_msg = MqMessage.build(
          MessageType.TEXT, MessageChannel.CORP_WECHAT,
          "user_id_0", "staff_id_0",
          "我需要帮助!", int(time.time() * 1000)
      )
      queues.receive_queue.produce(test_msg)

      # Process it, failing loudly if nothing was consumed.
      message = service.receive_queue.consume()
      assert message is not None, "produced message was not consumed from the receive queue"
      service.process_single_message(message)

      human_msg = queues.human_queue.consume()
      # The escalation logic is not enabled yet. Report SKIPPED instead of a
      # vacuous PASS so the missing coverage stays visible. Remove this skip
      # once the logic is enabled to activate the assertions below.
      pytest.skip("human intervention logic is not enabled yet")
      assert human_msg is not None
      assert human_msg.sender == "user_id_0"
      assert "用户对话需人工介入" in human_msg.content
  def test_initiative_conversation(test_env):
      """The service should proactively message a user (currently skipped).

      The initiative check is still exercised, so an exception there fails the
      test. The delivery assertions are skipped because a user whitelist currently
      prevents this test from verifying the sent message.
      """
      service, queues = test_env
      agent = service.get_agent_instance('staff_id_0', "user_id_0")
      agent.message_aggregation_sec = 0
      service._call_chat_api = Mock(return_value="主动发起模拟消息")

      # Make the agent want to initiate a conversation.
      agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
      # Initiation is time-restricted, so pin the time context to MORNING.
      agent.get_time_context = Mock(return_value=TimeContext.MORNING)
      service._check_initiative_conversations()

      sent_msg = queues.send_queue.consume()
      # Report SKIPPED rather than a vacuous PASS. Remove this skip once the
      # whitelist no longer blocks the test, which activates the assertions below.
      pytest.skip("initiative conversation is blocked by the user whitelist")
      assert sent_msg is not None
      assert "主动发起" in sent_msg.content
  def test_response_type_detector(test_env):
      """Check ResponseTypeDetector's Chinese-only and voice-suitability predicates.

      ``test_env`` is not referenced here; requesting it still runs the fixture.
      """
      # Plain Chinese text with ASCII punctuation: Chinese-only and voice-suitable.
      case1 = '大哥,那可得提前了解下天气,以便安排行程~我帮您查查明天北京天气?'
      assert ResponseTypeDetector.is_chinese_only(case1)
      assert ResponseTypeDetector.if_message_suitable_for_voice(case1)

      # Latin text is not Chinese-only.
      case2 = 'hi'
      assert not ResponseTypeDetector.is_chinese_only(case2)

      # Text containing a URL is not Chinese-only.
      case3 = '这是链接:http://domain.com'
      assert not ResponseTypeDetector.is_chinese_only(case3)

      # A shorter Chinese sentence is also voice-suitable.
      case4 = '大哥,那可得提前了解下天气'
      assert ResponseTypeDetector.if_message_suitable_for_voice(case4)