|
@@ -91,6 +91,8 @@ class DialogueManager:
|
|
|
self.state_cache = state_cache
|
|
|
self.current_state = DialogueState.GREETING
|
|
|
self.previous_state = DialogueState.INITIALIZED
|
|
|
+ # 用于消息处理失败时回滚
|
|
|
+ self.state_backup = (DialogueState.INITIALIZED, DialogueState.INITIALIZED)
|
|
|
# 目前实际仅用作调试,拼装prompt时使用history_dialogue_service获取
|
|
|
self.dialogue_history = []
|
|
|
self.user_profile = self.user_manager.get_user_profile(user_id)
|
|
@@ -121,12 +123,18 @@ class DialogueManager:
|
|
|
logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
|
|
|
|
|
|
def persist_state(self):
|
|
|
- """持久化对话状态"""
|
|
|
+ """持久化对话状态,只有当前状态处理成功后才应该做持久化"""
|
|
|
config = configs.get()
|
|
|
if not config.get('debug_flags', {}).get('disable_state_persistence', False):
|
|
|
return
|
|
|
self.state_cache.set_state(self.staff_id, self.user_id, self.current_state, self.previous_state)
|
|
|
|
|
|
+ def rollback_state(self):
|
|
|
+ logger.debug("staff[{}], user[{}]: rollback state: {}, previous state: {}".format(
|
|
|
+ self.staff_id, self.user_id, self.state_backup, self.current_state
|
|
|
+ ))
|
|
|
+ self.current_state, self.previous_state = self.state_backup
|
|
|
+
|
|
|
@staticmethod
|
|
|
def get_time_context(current_hour=None) -> TimeContext:
|
|
|
"""获取当前时间上下文"""
|
|
@@ -145,6 +153,15 @@ class DialogueManager:
|
|
|
else:
|
|
|
return TimeContext.NIGHT
|
|
|
|
|
|
+ def do_state_change(self, state: DialogueState):
|
|
|
+ self.state_backup = (self.current_state, self.previous_state)
|
|
|
+ if self.current_state == DialogueState.MESSAGE_AGGREGATING:
|
|
|
+ # 不需要更新previous_state
|
|
|
+ self.current_state = state
|
|
|
+ else:
|
|
|
+ self.previous_state = self.current_state
|
|
|
+ self.current_state = state
|
|
|
+
|
|
|
def update_state(self, message: Message) -> Tuple[bool, Optional[str]]:
|
|
|
"""根据用户消息更新对话状态,并返回是否需要发起回复 及下一条需处理的用户消息"""
|
|
|
message_text = message.content
|
|
@@ -167,7 +184,7 @@ class DialogueManager:
|
|
|
and message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
|
|
|
logger.debug("user_id: {}, last interaction time: {}".format(
|
|
|
self.user_id, datetime.fromtimestamp(self.last_interaction_time / 1000)))
|
|
|
- self.current_state = self.previous_state
|
|
|
+ self.do_state_change(self.previous_state)
|
|
|
else:
|
|
|
# 非空消息,更新最后交互时间,保持消息聚合状态
|
|
|
if message_text:
|
|
@@ -180,25 +197,13 @@ class DialogueManager:
|
|
|
return False, None
|
|
|
if message.type != MessageType.AGGREGATION_TRIGGER and self.message_aggregation_sec > 0:
|
|
|
# 收到有内容的用户消息,切换到消息聚合状态
|
|
|
- self.previous_state = self.current_state
|
|
|
- self.current_state = DialogueState.MESSAGE_AGGREGATING
|
|
|
+ self.do_state_change(DialogueState.MESSAGE_AGGREGATING)
|
|
|
self.unprocessed_messages.append(message_text)
|
|
|
# 更新最后交互时间
|
|
|
if message_text:
|
|
|
self.last_interaction_time = message_ts
|
|
|
- self.persist_state()
|
|
|
return False, message_text
|
|
|
|
|
|
- # 保存前一个状态
|
|
|
- self.previous_state = self.current_state
|
|
|
-
|
|
|
- # 检查是否长时间未交互(超过3小时)
|
|
|
- if self._get_hours_since_last_interaction() > 3:
|
|
|
- self.current_state = DialogueState.GREETING
|
|
|
- self.dialogue_history = [] # 重置对话历史
|
|
|
- self.consecutive_clarifications = 0 # 重置澄清计数
|
|
|
- self.complex_request_counter = 0 # 重置复杂请求计数
|
|
|
-
|
|
|
# 获得未处理的聚合消息,并清空未处理队列
|
|
|
if message_text:
|
|
|
self.unprocessed_messages.append(message_text)
|
|
@@ -219,15 +224,12 @@ class DialogueManager:
|
|
|
else:
|
|
|
self.consecutive_clarifications = 0
|
|
|
|
|
|
- # 更新状态并持久化
|
|
|
- self.current_state = new_state
|
|
|
- self.persist_state()
|
|
|
+ # 更新状态
|
|
|
+ self.do_state_change(new_state)
|
|
|
|
|
|
- # 更新最后交互时间
|
|
|
if message_text:
|
|
|
self.last_interaction_time = message_ts
|
|
|
|
|
|
- # 记录对话历史
|
|
|
if message_text:
|
|
|
self.dialogue_history.append({
|
|
|
"role": "user",
|
|
@@ -388,7 +390,6 @@ class DialogueManager:
|
|
|
TimeContext.EVENING]:
|
|
|
self.previous_state = self.current_state
|
|
|
self.current_state = DialogueState.GREETING
|
|
|
- self.persist_state()
|
|
|
return True
|
|
|
return False
|
|
|
|
|
@@ -430,7 +431,7 @@ class DialogueManager:
|
|
|
def _select_prompt(self, state):
|
|
|
state_to_prompt_map = {
|
|
|
DialogueState.GREETING: GENERAL_GREETING_PROMPT,
|
|
|
- DialogueState.CHITCHAT: GENERAL_GREETING_PROMPT,
|
|
|
+ DialogueState.CHITCHAT: CHITCHAT_PROMPT_COZE,
|
|
|
DialogueState.FAREWELL: GENERAL_GREETING_PROMPT
|
|
|
}
|
|
|
return state_to_prompt_map[state]
|