Преглед на файлове

Introduce transaction mechanism in dialogue manager state change

StrayWarrior преди 4 дни
родител
ревизия
85188e15e9
променени са 3 файла, в които са добавени 97 реда и са изтрити 44 реда
  1. 6 1
      agent_service.py
  2. 84 42
      dialogue_manager.py
  3. 7 1
      unit_test.py

+ 6 - 1
agent_service.py

@@ -190,7 +190,7 @@ class AgentService:
                     self._send_response(staff_id, user_id, resp, message_type)
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
-            # 当前消息处理成功,持久化agent状态
+            # 当前消息处理成功,commit并持久化agent状态
             agent.persist_state()
         except Exception as e:
             agent.rollback_state()
@@ -239,6 +239,8 @@ class AgentService:
 
             if should_initiate:
                 logger.warning("user: {}, initiate conversation".format(user_id))
+                # FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突 需要并入事件驱动框架
+                agent.do_state_change(DialogueState.GREETING)
                 try:
                     if agent.previous_state == DialogueState.INITIALIZED:
                         # 完全无交互历史的用户才使用此策略
@@ -251,7 +253,10 @@ class AgentService:
                         self._send_response(staff_id, user_id, resp, MessageType.TEXT)
                         if self.limit_initiative_conversation_rate:
                             time.sleep(random.randint(10,20))
+                    agent.persist_state()
                 except Exception as e:
+                    # FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突
+                    agent.rollback_state()
                     logger.error("Error in active greeting: {}".format(e))
             else:
                 logger.debug("user: {}, do not initiate conversation".format(user_id))

+ 84 - 42
dialogue_manager.py

@@ -55,6 +55,16 @@ class TimeContext(Enum):
     def __init__(self, description):
         self.description = description
 
+class DialogueStateChangeType(int, Enum):
+    STATE = 0
+    INTERACTION_TIME = 1
+    DIALOGUE_HISTORY = 2
+
+class DialogueStateChange:
+    def __init__(self, event_type: DialogueStateChangeType,old: Any, new: Any):
+        self.event_type = event_type
+        self.old = old
+        self.new = new
 
 class DialogueStateCache:
     def __init__(self):
@@ -95,13 +105,11 @@ class DialogueManager:
         self.state_cache = state_cache
         self.current_state = DialogueState.GREETING
         self.previous_state = DialogueState.INITIALIZED
-        # 用于消息处理失败时回滚
-        self.state_backup = (DialogueState.INITIALIZED, DialogueState.INITIALIZED)
         # 目前实际仅用作调试,拼装prompt时使用history_dialogue_service获取
         self.dialogue_history = []
         self.user_profile = self.user_manager.get_user_profile(user_id)
         self.staff_profile = self.user_manager.get_staff_profile(staff_id)
-        # FIXME(zhoutian): last_interaction_time也需要回滚
+        # FIXME: 交互时间和对话记录都涉及到回滚
         self.last_interaction_time = 0
         self.consecutive_clarifications = 0
         self.complex_request_counter = 0
@@ -113,6 +121,26 @@ class DialogueManager:
             config['storage']['history_dialogue']['api_base_url']
         )
         self._recover_state()
+        # 由于本地状态管理过于复杂,引入事务机制做状态回滚
+        self._uncommited_state_change = []
+
+    @staticmethod
+    def get_time_context(current_hour=None) -> TimeContext:
+        """获取当前时间上下文"""
+        if not current_hour:
+            current_hour = datetime.now().hour
+        if 5 <= current_hour < 8:
+            return TimeContext.EARLY_MORNING
+        elif 8 <= current_hour < 12:
+            return TimeContext.MORNING
+        elif 12 <= current_hour < 14:
+            return TimeContext.NOON
+        elif 14 <= current_hour < 18:
+            return TimeContext.AFTERNOON
+        elif 18 <= current_hour < 22:
+            return TimeContext.EVENING
+        else:
+            return TimeContext.NIGHT
 
     def _recover_state(self):
         self.current_state, self.previous_state = self.state_cache.get_state(self.staff_id, self.user_id)
@@ -127,45 +155,59 @@ class DialogueManager:
         time_for_read = datetime.fromtimestamp(self.last_interaction_time / 1000).strftime("%Y-%m-%d %H:%M:%S")
         logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
 
+    def update_interaction_time(self, timestamp_ms: int):
+        self._uncommited_state_change.append(DialogueStateChange(
+            DialogueStateChangeType.INTERACTION_TIME,
+            self.last_interaction_time,
+            timestamp_ms
+        ))
+        self.last_interaction_time = timestamp_ms
+
+    def append_dialogue_history(self, message: Dict):
+        self._uncommited_state_change.append(DialogueStateChange(
+            DialogueStateChangeType.DIALOGUE_HISTORY,
+            None,
+            1
+        ))
+        self.dialogue_history.append(message)
+
     def persist_state(self):
         """持久化对话状态,只有当前状态处理成功后才应该做持久化"""
+        self.commit()
         config = configs.get()
         if config.get('debug_flags', {}).get('disable_database_write', False):
             return
         self.state_cache.set_state(self.staff_id, self.user_id, self.current_state, self.previous_state)
 
     def rollback_state(self):
-        logger.debug("staff[{}], user[{}]: rollback state: {}, previous state: {}".format(
-            self.staff_id, self.user_id, self.state_backup, self.current_state
-        ))
-        self.current_state, self.previous_state = self.state_backup
+        logger.info(f"staff[{self.staff_id}], user[{self.user_id}]: reverse state")
+        for entry in reversed(self._uncommited_state_change):
+            if entry.event_type == DialogueStateChangeType.STATE:
+                self.current_state, self.previous_state = entry.old
+            elif entry.event_type == DialogueStateChangeType.INTERACTION_TIME:
+                self.last_interaction_time = entry.old
+            elif entry.event_type == DialogueStateChangeType.DIALOGUE_HISTORY:
+                self.dialogue_history.pop()
+            else:
+                logger.error(f"unimplemented type: [{entry.event_type}]")
+        self._uncommited_state_change.clear()
 
-    @staticmethod
-    def get_time_context(current_hour=None) -> TimeContext:
-        """获取当前时间上下文"""
-        if not current_hour:
-            current_hour = datetime.now().hour
-        if 5 <= current_hour < 8:
-            return TimeContext.EARLY_MORNING
-        elif 8 <= current_hour < 12:
-            return TimeContext.MORNING
-        elif 12 <= current_hour < 14:
-            return TimeContext.NOON
-        elif 14 <= current_hour < 18:
-            return TimeContext.AFTERNOON
-        elif 18 <= current_hour < 22:
-            return TimeContext.EVENING
-        else:
-            return TimeContext.NIGHT
+    def commit(self):
+        self._uncommited_state_change.clear()
 
     def do_state_change(self, state: DialogueState):
-        self.state_backup = (self.current_state, self.previous_state)
+        state_backup = (self.current_state, self.previous_state)
         if self.current_state == DialogueState.MESSAGE_AGGREGATING:
             # MESSAGE_AGGREGATING不能成为previous_state,仅使用state_backup做回退
             self.current_state = state
         else:
             self.previous_state = self.current_state
             self.current_state = state
+        self._uncommited_state_change.append(DialogueStateChange(
+            DialogueStateChangeType.STATE,
+            state_backup,
+            (self.current_state, self.previous_state)
+        ))
 
     def update_state(self, message: Message) -> Tuple[bool, Optional[str]]:
         """根据用户消息更新对话状态,并返回是否需要发起回复 及下一条需处理的用户消息"""
@@ -174,7 +216,7 @@ class DialogueManager:
         # 如果当前已经是人工介入状态,保持该状态
         if self.current_state == DialogueState.HUMAN_INTERVENTION:
             # 记录对话历史,但不改变状态
-            self.dialogue_history.append({
+            self.append_dialogue_history({
                 "role": "user",
                 "content": message_text,
                 "timestamp": int(time.time() * 1000),
@@ -185,15 +227,17 @@ class DialogueManager:
         # 检查是否处于消息聚合状态
         if self.current_state == DialogueState.MESSAGE_AGGREGATING:
             # 收到的是特殊定时触发的空消息,且在聚合中,且已经超时,继续处理
-            if message.type == MessageType.AGGREGATION_TRIGGER \
-                    and message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
-                logger.debug("user_id: {}, last interaction time: {}".format(
-                    self.user_id, datetime.fromtimestamp(self.last_interaction_time / 1000)))
+            if message.type == MessageType.AGGREGATION_TRIGGER:
+                if message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
+                    logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: exit aggregation waiting")
+                else:
+                    logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: continue aggregation waiting")
+                    return False, message_text
             else:
                 # 非空消息,更新最后交互时间,保持消息聚合状态
                 if message_text:
                     self.unprocessed_messages.append(message_text)
-                    self.last_interaction_time = message_ts
+                    self.update_interaction_time(message_ts)
                 return False, message_text
         else:
             if message.type == MessageType.AGGREGATION_TRIGGER:
@@ -206,7 +250,7 @@ class DialogueManager:
                 self.unprocessed_messages.append(message_text)
                 # 更新最后交互时间
                 if message_text:
-                    self.last_interaction_time = message_ts
+                    self.update_interaction_time(message_ts)
                 return False, message_text
 
         # 获得未处理的聚合消息,并清空未处理队列
@@ -233,8 +277,8 @@ class DialogueManager:
         self.do_state_change(new_state)
 
         if message_text:
-            self.last_interaction_time = message_ts
-            self.dialogue_history.append({
+            self.update_interaction_time(message_ts)
+            self.append_dialogue_history({
                 "role": "user",
                 "content": message_text,
                 "timestamp": message_ts,
@@ -325,13 +369,13 @@ class DialogueManager:
     def resume_from_human_intervention(self) -> None:
         """从人工介入状态恢复"""
         if self.current_state == DialogueState.HUMAN_INTERVENTION:
-            self.current_state = DialogueState.GREETING
+            self.do_state_change(DialogueState.CHITCHAT)
             self.human_intervention_triggered = False
             self.consecutive_clarifications = 0
             self.complex_request_counter = 0
 
             # 记录恢复事件
-            self.dialogue_history.append({
+            self.append_dialogue_history({
                 "role": "system",
                 "content": "已从人工介入状态恢复到自动对话",
                 "timestamp": int(time.time() * 1000),
@@ -345,14 +389,14 @@ class DialogueManager:
             return None
 
         # 记录响应到对话历史
-        current_ts = int(time.time() * 1000)
-        self.dialogue_history.append({
+        message_ts = int(time.time() * 1000)
+        self.append_dialogue_history({
             "role": "assistant",
             "content": llm_response,
-            "timestamp": current_ts,
+            "timestamp": message_ts,
             "state": self.current_state.name
         })
-        self.last_interaction_time = current_ts
+        self.update_interaction_time(message_ts)
 
         return llm_response
 
@@ -392,8 +436,6 @@ class DialogueManager:
             # 根据时间上下文决定主动交互的状态
         if time_context in [TimeContext.MORNING,
                             TimeContext.NOON, TimeContext.AFTERNOON]:
-            self.previous_state = self.current_state
-            self.current_state = DialogueState.GREETING
             return True
         return False
 

+ 7 - 1
unit_test.py

@@ -70,6 +70,7 @@ def test_agent_state_change(test_env):
     agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
     assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
     assert agent.previous_state == DialogueState.GREETING
+    agent.commit()
 
     agent.do_state_change(DialogueState.CHITCHAT)
     assert agent.current_state == DialogueState.CHITCHAT
@@ -89,11 +90,16 @@ def test_agent_state_change(test_env):
     assert agent.previous_state == DialogueState.CHITCHAT
 
     agent.do_state_change(DialogueState.MESSAGE_AGGREGATING)
+    agent.commit()
+
     agent.do_state_change(DialogueState.CHITCHAT)
-    assert agent.state_backup == (DialogueState.MESSAGE_AGGREGATING, DialogueState.CHITCHAT)
     agent.rollback_state()
     assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
 
+    agent.rollback_state()
+    # no state should be rollback
+    assert agent.current_state == DialogueState.MESSAGE_AGGREGATING
+
 
 def test_response_sanitization(test_env):
     case1 = '[2024-01-01 12:00:00] 你好'