Przeglądaj źródła

Update dialogue_manager: return need response after update state

StrayWarrior 2 tygodni temu
rodzic
commit
c9af53d172
2 zmienionych plików z 46 dodań i 29 usunięć
  1. 6 4
      agent_service.py
  2. 40 25
      dialogue_manager.py

+ 6 - 4
agent_service.py

@@ -125,22 +125,24 @@ class AgentService:
 
         # 更新对话状态
         logging.debug("process message: {}".format(message))
-        dialogue_state, message_text = agent.update_state(message)
-        logging.debug("user: {}, next state: {}".format(user_id, dialogue_state))
+        need_response, message_text = agent.update_state(message)
+        logging.debug("user: {}, next state: {}".format(user_id, agent.current_state))
 
         # 根据状态路由消息
         if agent.is_in_human_intervention():
             self._route_to_human_intervention(user_id, message)
-        elif dialogue_state == DialogueState.MESSAGE_AGGREGATING:
+        elif agent.current_state == DialogueState.MESSAGE_AGGREGATING:
             if message.type != MessageType.AGGREGATION_TRIGGER:
                 # 产生一个触发器,但是不能由触发器递归产生
                 logging.debug("user: {}, waiting next message for aggregation".format(user_id))
                 self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
             return
-        else:
+        elif need_response:
             # 先更新用户画像再处理回复
             self._update_user_profile(user_id, user_profile, message_text)
             self._get_chat_response(user_id, agent, message_text)
+        else:
+            logging.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
 
     def _route_to_human_intervention(self, user_id: str, origin_message: Message):
         """路由到人工干预"""

+ 40 - 25
dialogue_manager.py

@@ -121,6 +121,12 @@ class DialogueManager:
         else:
             # 默认设置为24小时前
             self.last_interaction_time = int(time.time() * 1000) - 24 * 3600 * 1000
+        time_for_read = datetime.fromtimestamp(self.last_interaction_time / 1000).strftime("%Y-%m-%d %H:%M:%S")
+        logging.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
+
+    def persist_state(self):
+        """持久化对话状态"""
+        self.state_cache.set_state(self.staff_id, self.user_id, self.current_state, self.previous_state)
 
     @staticmethod
     def get_current_time_context() -> TimeContext:
@@ -139,8 +145,8 @@ class DialogueManager:
         else:
             return TimeContext.NIGHT
 
-    def update_state(self, message: Message) -> Tuple[DialogueState, Optional[str]]:
-        """根据用户消息更新对话状态,并返回下一条需处理的用户消息"""
+    def update_state(self, message: Message) -> Tuple[bool, Optional[str]]:
+        """根据用户消息更新对话状态,并返回是否需要发起回复 及下一条需处理的用户消息"""
         message_text = message.content
         message_ts = message.sendTime
         # 如果当前已经是人工介入状态,保持该状态
@@ -152,7 +158,7 @@ class DialogueManager:
                 "timestamp": int(time.time() * 1000),
                 "state": self.current_state.name
             })
-            return self.current_state, message_text
+            return False, message_text
 
         # 检查是否处于消息聚合状态
         if self.current_state == DialogueState.MESSAGE_AGGREGATING:
@@ -167,16 +173,21 @@ class DialogueManager:
                 if message_text:
                     self.unprocessed_messages.append(message_text)
                     self.last_interaction_time = message_ts
-                return self.current_state, message_text
-        elif message.type != MessageType.AGGREGATION_TRIGGER and self.message_aggregation_sec > 0:
-            # 收到有内容的用户消息,切换到消息聚合状态
-            self.previous_state = self.current_state
-            self.current_state = DialogueState.MESSAGE_AGGREGATING
-            self.unprocessed_messages.append(message_text)
-            # 更新最后交互时间
-            if message_text:
-                self.last_interaction_time = message_ts
-            return self.current_state, message_text
+                return False, message_text
+        else:
+            if message.type == MessageType.AGGREGATION_TRIGGER:
+                # 未在聚合状态中,收到的聚合触发消息为过时消息,不应当处理
+                return False, None
+            if message.type != MessageType.AGGREGATION_TRIGGER and self.message_aggregation_sec > 0:
+                # 收到有内容的用户消息,切换到消息聚合状态
+                self.previous_state = self.current_state
+                self.current_state = DialogueState.MESSAGE_AGGREGATING
+                self.unprocessed_messages.append(message_text)
+                # 更新最后交互时间
+                if message_text:
+                    self.last_interaction_time = message_ts
+                self.persist_state()
+                return False, message_text
 
         # 保存前一个状态
         self.previous_state = self.current_state
@@ -209,7 +220,7 @@ class DialogueManager:
 
         # 更新状态并持久化
         self.current_state = new_state
-        self.state_cache.set_state(self.staff_id, self.user_id, self.current_state, self.previous_state)
+        self.persist_state()
 
         # 更新最后交互时间
         if message_text:
@@ -224,7 +235,7 @@ class DialogueManager:
                 "state": self.current_state.name
             })
 
-        return self.current_state, message_text
+        return True, message_text
 
     def _determine_state_from_message(self, message_text: Optional[str]) -> DialogueState:
         """根据消息内容确定对话状态"""
@@ -367,14 +378,16 @@ class DialogueManager:
 
         threshold = thresholds.get(interaction_frequency, 12)
 
-        # 如果足够时间已经过去
-        if hours_passed >= threshold:
+        if hours_passed < threshold:
+            return False
             # 根据时间上下文决定主动交互的状态
-            if time_context in [TimeContext.EARLY_MORNING, TimeContext.MORNING,
-                                TimeContext.NOON, TimeContext.AFTERNOON,
-                                TimeContext.EVENING]:
-                return True
-
+        if time_context in [TimeContext.EARLY_MORNING, TimeContext.MORNING,
+                            TimeContext.NOON, TimeContext.AFTERNOON,
+                            TimeContext.EVENING]:
+            self.previous_state = self.current_state
+            self.current_state = DialogueState.GREETING
+            self.persist_state()
+            return True
         return False
 
     def is_in_human_intervention(self) -> bool:
@@ -415,13 +428,15 @@ class DialogueManager:
         state_to_prompt_map = {
             DialogueState.GREETING: GENERAL_GREETING_PROMPT,
             DialogueState.CHITCHAT: GENERAL_GREETING_PROMPT,
+            DialogueState.FAREWELL: GENERAL_GREETING_PROMPT
         }
         return state_to_prompt_map[state]
 
     def _select_coze_bot(self, state):
         state_to_bot_map = {
             DialogueState.GREETING: '7479005417885417487',
-            DialogueState.CHITCHAT: '7479005417885417487'
+            DialogueState.CHITCHAT: '7479005417885417487',
+            DialogueState.FAREWELL: '7479005417885417487'
         }
         return state_to_bot_map[state]
 
@@ -476,8 +491,8 @@ class DialogueManager:
             custom_variables.pop('user_profile', None)
             config['custom_variables'] = custom_variables
             config['bot_id'] = self._select_coze_bot(self.current_state)
-            #FIXME(zhoutian): 这种方法并不可靠,需要通过状态来判断
-            if not user_message:
+            #FIXME(zhoutian): 这种方法并不可靠,需要结合状态来判断
+            if self.current_state == DialogueState.GREETING and not messages:
                 messages.append(cozepy.Message.build_user_question_text('请开始对话'))
             #FIXME(zhoutian): 临时报警
             if user_message and not messages: