unit_test.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import pytest
  5. from unittest.mock import Mock, MagicMock
  6. from pqai_agent.agent_service import AgentService, MemoryQueueBackend
  7. from pqai_agent.dialogue_manager import DialogueState, TimeContext
  8. from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
  9. from pqai_agent.response_type_detector import ResponseTypeDetector
  10. from pqai_agent.user_manager import LocalUserManager
  11. import time
  12. import logging
  13. class TestMessageQueues:
  14. """测试用消息队列实现"""
  15. def __init__(self, receive_queue, send_queue, human_queue):
  16. self.receive_queue = receive_queue
  17. self.send_queue = send_queue
  18. self.human_queue = human_queue
  19. @pytest.fixture
  20. def test_env():
  21. """测试环境初始化"""
  22. logging.getLogger().setLevel(logging.DEBUG)
  23. user_manager = LocalUserManager()
  24. user_relation_manager = Mock()
  25. user_relation_manager.get_user_tags = Mock(return_value=['AgentTest1'])
  26. user_relation_manager.list_staff_users = Mock(return_value=[{'staff_id': 'staff_id_0', 'user_id': 'user_id_0'}])
  27. receive_queue = MemoryQueueBackend()
  28. send_queue = MemoryQueueBackend()
  29. human_queue = MemoryQueueBackend()
  30. queues = TestMessageQueues(receive_queue, send_queue, human_queue)
  31. # 创建Agent服务实例
  32. service = AgentService(
  33. receive_backend=receive_queue,
  34. send_backend=send_queue,
  35. human_backend=human_queue,
  36. user_manager=user_manager,
  37. user_relation_manager=user_relation_manager
  38. )
  39. service.user_profile_extractor.extract_profile_info = Mock(return_value=None)
  40. # 替换LLM调用为模拟响应
  41. service._call_chat_api = Mock(return_value="模拟响应")
  42. return service, queues
  43. def test_agent_state_change(test_env):
  44. service, _ = test_env
  45. agent = service.get_agent_instance('staff_id_0', 'user_id_0')
  46. assert agent.current_state == DialogueState.INITIALIZED
  47. assert agent.previous_state == DialogueState.INITIALIZED
  48. agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
  49. assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
  50. assert agent.previous_state == DialogueState.INITIALIZED
  51. agent.do_state_change(DialogueState.GREETING)
  52. assert agent.current_state == DialogueState.GREETING
  53. assert agent.previous_state == DialogueState.INITIALIZED
  54. agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
  55. assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
  56. assert agent.previous_state == DialogueState.GREETING
  57. agent.commit()
  58. agent.do_state_change(DialogueState.CHITCHAT)
  59. assert agent.current_state == DialogueState.CHITCHAT
  60. assert agent.previous_state == DialogueState.GREETING
  61. agent.rollback_state()
  62. assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
  63. assert agent.previous_state == DialogueState.GREETING
  64. agent.do_state_change(DialogueState.CHITCHAT)
  65. assert agent.current_state == DialogueState.CHITCHAT
  66. assert agent.previous_state == DialogueState.GREETING
  67. agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
  68. agent.do_state_change(DialogueState.CHITCHAT)
  69. assert agent.current_state == DialogueState.CHITCHAT
  70. assert agent.previous_state == DialogueState.CHITCHAT
  71. agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
  72. agent.commit()
  73. agent.do_state_change(DialogueState.CHITCHAT)
  74. agent.rollback_state()
  75. assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
  76. agent.rollback_state()
  77. # no state should be rollback
  78. assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
  79. def test_response_sanitization(test_env):
  80. case1 = '[2024-01-01 12:00:00] 你好'
  81. ret1 = AgentService.sanitize_response(case1)
  82. assert ret1 == '你好'
  83. case1 = '2024-01-01 12:00:00 你好'
  84. ret2 = AgentService.sanitize_response(case1)
  85. assert ret2 == '你好'
  86. def test_normal_conversation_flow(test_env):
  87. """测试正常对话流程"""
  88. service, queues = test_env
  89. service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
  90. # 准备测试消息
  91. test_msg = MqMessage.build(
  92. MessageType.TEXT, MessageChannel.CORP_WECHAT,
  93. 'user_id_0', 'staff_id_0', '你好', int(time.time() * 1000))
  94. queues.receive_queue.produce(test_msg)
  95. # 处理消息
  96. message = service.receive_queue.consume()
  97. if message:
  98. service.process_single_message(message)
  99. # 验证响应消息
  100. sent_msg = queues.send_queue.consume()
  101. assert sent_msg is not None
  102. assert sent_msg.receiver == "user_id_0"
  103. assert "模拟响应" in sent_msg.content
  104. def test_aggregated_conversation_flow(test_env):
  105. """测试聚合对话流程"""
  106. service, queues = test_env
  107. service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 1
  108. # 准备测试消息
  109. ts_begin = int(time.time() * 1000)
  110. test_msg = MqMessage.build(
  111. MessageType.TEXT, MessageChannel.CORP_WECHAT,
  112. 'user_id_0', 'staff_id_0', '你好', ts_begin)
  113. queues.receive_queue.produce(test_msg)
  114. test_msg = MqMessage.build(
  115. MessageType.TEXT, MessageChannel.CORP_WECHAT,
  116. 'user_id_0', 'staff_id_0', '我是老李', ts_begin + 500)
  117. queues.receive_queue.produce(test_msg)
  118. # 处理消息
  119. message = service.receive_queue.consume()
  120. if message:
  121. service.process_single_message(message)
  122. # 验证第一次响应消息
  123. sent_msg = queues.send_queue.consume()
  124. assert sent_msg is None
  125. message = service.receive_queue.consume()
  126. if message:
  127. service.process_single_message(message)
  128. # 验证第二次响应消息
  129. sent_msg = queues.send_queue.consume()
  130. assert sent_msg is None
  131. # 模拟定时器产生空消息触发响应
  132. service.process_single_message(MqMessage.build(
  133. MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
  134. 'user_id_0', 'staff_id_0', None, ts_begin + 2000
  135. ))
  136. # 验证第三次响应消息
  137. sent_msg = queues.send_queue.consume()
  138. assert sent_msg is not None
  139. assert sent_msg.receiver == "user_id_0"
  140. assert "模拟响应" in sent_msg.content
  141. def test_human_intervention_trigger(test_env):
  142. """测试触发人工干预"""
  143. service, queues = test_env
  144. service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
  145. # 准备需要人工干预的消息
  146. test_msg = MqMessage.build(
  147. MessageType.TEXT, MessageChannel.CORP_WECHAT,
  148. "user_id_0", "staff_id_0",
  149. "我需要帮助!", int(time.time() * 1000)
  150. )
  151. queues.receive_queue.produce(test_msg)
  152. # 处理消息
  153. message = service.receive_queue.consume()
  154. if message:
  155. service.process_single_message(message)
  156. # 验证人工队列消息
  157. human_msg = queues.human_queue.consume()
  158. # 由于相关逻辑未启用,临时关闭该测试
  159. return
  160. assert human_msg is not None
  161. assert human_msg.sender == "user_id_0"
  162. assert "用户对话需人工介入" in human_msg.content
  163. def test_initiative_conversation(test_env):
  164. """测试主动发起对话"""
  165. service, queues = test_env
  166. service.get_agent_instance('staff_id_0', "user_id_0").message_aggregation_sec = 0
  167. service._call_chat_api = Mock(return_value="主动发起模拟消息")
  168. # 设置Agent需要主动发起对话
  169. agent = service.get_agent_instance('staff_id_0', "user_id_0")
  170. agent.should_initiate_conversation = Mock(return_value=(True, MagicMock()))
  171. # 发起对话有时间限制
  172. agent.get_time_context = Mock(return_value=TimeContext.MORNING)
  173. service._check_initiative_conversations()
  174. # 验证主动发起的消息 (由于当前有白名单,无法支持测试)
  175. sent_msg = queues.send_queue.consume()
  176. # assert sent_msg is not None
  177. # assert "主动发起" in sent_msg.content
  178. def test_response_type_detector(test_env):
  179. case1 = '大哥,那可得提前了解下天气,以便安排行程~我帮您查查明天北京天气?'
  180. assert ResponseTypeDetector.is_chinese_only(case1) == True
  181. assert ResponseTypeDetector.if_message_suitable_for_voice(case1) == False
  182. case2 = 'hi'
  183. assert ResponseTypeDetector.is_chinese_only(case2) == False
  184. case3 = '这是链接:http://domain.com'
  185. assert ResponseTypeDetector.is_chinese_only(case3) == False
  186. case4 = '大哥,那可得提前了解下天气'
  187. assert ResponseTypeDetector.if_message_suitable_for_voice(case4) == True