Bladeren bron

Update agent_service: add response sanitization

StrayWarrior 1 week geleden
bovenliggende
commit
27eafb8925
2 gewijzigde bestanden met toevoegingen van 19 en 3 verwijderingen
  1. 10 3
      agent_service.py
  2. 9 0
      unit_test.py

+ 10 - 3
agent_service.py

@@ -1,7 +1,7 @@
 #! /usr/bin/env python
 # -*- coding: utf-8 -*-
 # vim:fenc=utf-8
-
+import re
 import sys
 import time
 import random
@@ -178,8 +178,8 @@ class AgentService:
         """处理LLM响应"""
         chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
         logger.debug(chat_config)
-        # FIXME(zhoutian): 临时处理去除头尾的空格
-        chat_response = self._call_chat_api(chat_config).strip()
+        chat_response = self._call_chat_api(chat_config)
+        chat_response = self.sanitize_response(chat_response)
 
         if response := agent.generate_response(chat_response):
             logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: response: {response}")
@@ -217,6 +217,13 @@ class AgentService:
             raise Exception('Unsupported chat service type: {}'.format(self.chat_service_type))
         return response
 
+    @staticmethod
+    def sanitize_response(response: str):
+        pattern = r'\[?\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]?'
+        response = re.sub(pattern, '', response)
+        response = response.strip()
+        return response
+
 if __name__ == "__main__":
     config = configs.get()
     logging_service.setup_root_logger()

+ 9 - 0
unit_test.py

@@ -47,6 +47,15 @@ def test_env():
 
     return service, queues
 
+def test_response_sanitization(test_env):
+    case1 = '[2024-01-01 12:00:00] 你好'
+    ret1 = AgentService.sanitize_response(case1)
+    assert ret1 == '你好'
+
+    case1 = '2024-01-01 12:00:00 你好'
+    ret2 = AgentService.sanitize_response(case1)
+    assert ret2 == '你好'
+
 def test_normal_conversation_flow(test_env):
     """测试正常对话流程"""
     service, queues = test_env