Forráskód Böngészése

Update agent_service: use message type detector

StrayWarrior 3 hete
szülő
commit
717060d0b6
1 módosított fájl, 12 hozzáadás és 5 törlés
  1. 12 5
      agent_service.py

+ 12 - 5
agent_service.py

@@ -19,6 +19,7 @@ import logging_service
 from logging_service import logger
 from chat_service import CozeChat, ChatServiceType
 from dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
+from response_type_detector import ResponseTypeDetector
 from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager
 from openai import OpenAI
 from message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
@@ -46,6 +47,7 @@ class AgentService:
         self.user_manager = user_manager
         self.user_relation_manager = user_relation_manager
         self.user_profile_extractor = UserProfileExtractor()
+        self.response_type_detector = ResponseTypeDetector()
         self.agent_registry: Dict[str, DialogueManager] = {}
 
         self.llm_client = OpenAI(
@@ -145,7 +147,12 @@ class AgentService:
                 self._update_user_profile(user_id, user_profile, agent.dialogue_history[-10:])
                 resp = self._get_chat_response(user_id, agent, message_text)
                 if resp:
-                    self._send_response(staff_id, user_id, resp)
+                    recent_dialogue = agent.dialogue_history[-10:]
+                    if len(recent_dialogue) < 2:
+                        message_type = MessageType.TEXT
+                    else:
+                        message_type = self.response_type_detector.detect_type(recent_dialogue[:-1], recent_dialogue[-1])
+                    self._send_response(staff_id, user_id, resp, message_type)
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
             # 当前消息处理成功,持久化agent状态
@@ -154,8 +161,8 @@ class AgentService:
             agent.rollback_state()
             raise e
 
-    def _send_response(self, staff_id, user_id, response):
-        logger.warning(f"staff[{staff_id}] user[{user_id}]: response: {response}")
+    def _send_response(self, staff_id, user_id, response, message_type: MessageType):
+        logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
         current_ts = int(time.time() * 1000)
         user_tags = self.user_relation_manager.get_user_tags(user_id)
         # FIXME(zhoutian)
@@ -166,7 +173,7 @@ class AgentService:
             logger.warning(f"staff[{staff_id}] user[{user_id}]: skip reply")
             return None
         self.send_queue.produce(
-            Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
+            Message.build(message_type, MessageChannel.CORP_WECHAT,
                           staff_id, user_id, response, current_ts)
         )
 
@@ -193,7 +200,7 @@ class AgentService:
                 logger.warning("user: {}, initiate conversation".format(user_id))
                 resp = self._get_chat_response(user_id, agent, None)
                 if resp:
-                    self._send_response(staff_id, user_id, resp)
+                    self._send_response(staff_id, user_id, resp, MessageType.TEXT)
                     time.sleep(random.randint(10,20))
             else:
                 logger.debug("user: {}, do not initiate conversation".format(user_id))