瀏覽代碼

Update dialogue_manager: rollback state if error in process

StrayWarrior 1 周之前
父節點
當前提交
b4f26f87d7
共有 2 個文件被更改,包括 42 次插入36 次删除
  1. 19 14
      agent_service.py
  2. 23 22
      dialogue_manager.py

+ 19 - 14
agent_service.py

@@ -132,20 +132,25 @@ class AgentService:
         logger.debug("user: {}, next state: {}".format(user_id, agent.current_state))
 
         # 根据状态路由消息
-        if agent.is_in_human_intervention():
-            self._route_to_human_intervention(user_id, message)
-        elif agent.current_state == DialogueState.MESSAGE_AGGREGATING:
-            if message.type != MessageType.AGGREGATION_TRIGGER:
-                # 产生一个触发器,但是不能由触发器递归产生
-                logger.debug("user: {}, waiting next message for aggregation".format(user_id))
-                self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
-            return
-        elif need_response:
-            # 先更新用户画像再处理回复
-            self._update_user_profile(user_id, user_profile, message_text)
-            self._get_chat_response(user_id, agent, message_text)
-        else:
-            logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
+        try:
+            if agent.is_in_human_intervention():
+                self._route_to_human_intervention(user_id, message)
+            elif agent.current_state == DialogueState.MESSAGE_AGGREGATING:
+                if message.type != MessageType.AGGREGATION_TRIGGER:
+                    # 产生一个触发器,但是不能由触发器递归产生
+                    logger.debug("user: {}, waiting next message for aggregation".format(user_id))
+                    self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
+            elif need_response:
+                # 先更新用户画像再处理回复
+                self._update_user_profile(user_id, user_profile, message_text)
+                self._get_chat_response(user_id, agent, message_text)
+            else:
+                logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
+            # 当前消息处理成功,持久化agent状态
+            agent.persist_state()
+        except Exception as e:
+            agent.rollback_state()
+            raise e
 
     def _route_to_human_intervention(self, user_id: str, origin_message: Message):
         """路由到人工干预"""

+ 23 - 22
dialogue_manager.py

@@ -91,6 +91,8 @@ class DialogueManager:
         self.state_cache = state_cache
         self.current_state = DialogueState.GREETING
         self.previous_state = DialogueState.INITIALIZED
+        # 用于消息处理失败时回滚
+        self.state_backup = (DialogueState.INITIALIZED, DialogueState.INITIALIZED)
         # 目前实际仅用作调试,拼装prompt时使用history_dialogue_service获取
         self.dialogue_history = []
         self.user_profile = self.user_manager.get_user_profile(user_id)
@@ -121,12 +123,18 @@ class DialogueManager:
         logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
 
     def persist_state(self):
-        """持久化对话状态"""
+        """持久化对话状态,只有当前状态处理成功后才应该做持久化"""
         config = configs.get()
         if not config.get('debug_flags', {}).get('disable_state_persistence', False):
             return
         self.state_cache.set_state(self.staff_id, self.user_id, self.current_state, self.previous_state)
 
+    def rollback_state(self):
+        logger.debug("staff[{}], user[{}]: rollback state: {}, previous state: {}".format(
+            self.staff_id, self.user_id, self.state_backup, self.current_state
+        ))
+        self.current_state, self.previous_state = self.state_backup
+
     @staticmethod
     def get_time_context(current_hour=None) -> TimeContext:
         """获取当前时间上下文"""
@@ -145,6 +153,15 @@ class DialogueManager:
         else:
             return TimeContext.NIGHT
 
+    def do_state_change(self, state: DialogueState):
+        self.state_backup = (self.current_state, self.previous_state)
+        if self.current_state == DialogueState.MESSAGE_AGGREGATING:
+            # 不需要更新previous_state
+            self.current_state = state
+        else:
+            self.previous_state = self.current_state
+            self.current_state = state
+
     def update_state(self, message: Message) -> Tuple[bool, Optional[str]]:
         """根据用户消息更新对话状态,并返回是否需要发起回复 及下一条需处理的用户消息"""
         message_text = message.content
@@ -167,7 +184,7 @@ class DialogueManager:
                     and message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
                 logger.debug("user_id: {}, last interaction time: {}".format(
                     self.user_id, datetime.fromtimestamp(self.last_interaction_time / 1000)))
-                self.current_state = self.previous_state
+                self.do_state_change(self.previous_state)
             else:
                 # 非空消息,更新最后交互时间,保持消息聚合状态
                 if message_text:
@@ -180,25 +197,13 @@ class DialogueManager:
                 return False, None
             if message.type != MessageType.AGGREGATION_TRIGGER and self.message_aggregation_sec > 0:
                 # 收到有内容的用户消息,切换到消息聚合状态
-                self.previous_state = self.current_state
-                self.current_state = DialogueState.MESSAGE_AGGREGATING
+                self.do_state_change(DialogueState.MESSAGE_AGGREGATING)
                 self.unprocessed_messages.append(message_text)
                 # 更新最后交互时间
                 if message_text:
                     self.last_interaction_time = message_ts
-                self.persist_state()
                 return False, message_text
 
-        # 保存前一个状态
-        self.previous_state = self.current_state
-
-        # 检查是否长时间未交互(超过3小时)
-        if self._get_hours_since_last_interaction() > 3:
-            self.current_state = DialogueState.GREETING
-            self.dialogue_history = []  # 重置对话历史
-            self.consecutive_clarifications = 0  # 重置澄清计数
-            self.complex_request_counter = 0  # 重置复杂请求计数
-
         # 获得未处理的聚合消息,并清空未处理队列
         if message_text:
             self.unprocessed_messages.append(message_text)
@@ -219,15 +224,12 @@ class DialogueManager:
         else:
             self.consecutive_clarifications = 0
 
-        # 更新状态并持久化
-        self.current_state = new_state
-        self.persist_state()
+        # 更新状态
+        self.do_state_change(new_state)
 
-        # 更新最后交互时间
         if message_text:
             self.last_interaction_time = message_ts
 
-        # 记录对话历史
         if message_text:
             self.dialogue_history.append({
                 "role": "user",
@@ -388,7 +390,6 @@ class DialogueManager:
                             TimeContext.EVENING]:
             self.previous_state = self.current_state
             self.current_state = DialogueState.GREETING
-            self.persist_state()
             return True
         return False
 
@@ -430,7 +431,7 @@ class DialogueManager:
     def _select_prompt(self, state):
         state_to_prompt_map = {
             DialogueState.GREETING: GENERAL_GREETING_PROMPT,
-            DialogueState.CHITCHAT: GENERAL_GREETING_PROMPT,
+            DialogueState.CHITCHAT: CHITCHAT_PROMPT_COZE,
             DialogueState.FAREWELL: GENERAL_GREETING_PROMPT
         }
         return state_to_prompt_map[state]