|
@@ -7,8 +7,13 @@ from typing import Dict, List, Optional, Tuple, Any
|
|
|
from datetime import datetime
|
|
|
import time
|
|
|
import logging
|
|
|
+
|
|
|
+import pymysql.cursors
|
|
|
+
|
|
|
import configs
|
|
|
import cozepy
|
|
|
+
|
|
|
+from database import MySQLManager
|
|
|
from history_dialogue_service import HistoryDialogueService
|
|
|
|
|
|
from chat_service import ChatServiceType
|
|
@@ -33,13 +38,14 @@ class DummyVectorMemoryManager:
|
|
|
return []
|
|
|
|
|
|
|
|
|
-class DialogueState(Enum):
|
|
|
- GREETING = auto() # 问候状态
|
|
|
- CHITCHAT = auto() # 闲聊状态
|
|
|
- CLARIFICATION = auto() # 澄清状态
|
|
|
- FAREWELL = auto() # 告别状态
|
|
|
- HUMAN_INTERVENTION = auto() # 人工介入状态
|
|
|
- MESSAGE_AGGREGATING = auto() # 等待消息状态
|
|
|
+class DialogueState(int, Enum):
|
|
|
+ INITIALIZED = 0
|
|
|
+ GREETING = 1 # 问候状态
|
|
|
+ CHITCHAT = 2 # 闲聊状态
|
|
|
+ CLARIFICATION = 3 # 澄清状态
|
|
|
+ FAREWELL = 4 # 告别状态
|
|
|
+ HUMAN_INTERVENTION = 5 # 人工介入状态
|
|
|
+ MESSAGE_AGGREGATING = 6 # 等待消息状态
|
|
|
|
|
|
|
|
|
class TimeContext(Enum):
|
|
@@ -53,15 +59,42 @@ class TimeContext(Enum):
|
|
|
def __init__(self, description):
|
|
|
self.description = description
|
|
|
|
|
|
+
|
|
|
+class DialogueStateCache:
|
|
|
+ def __init__(self):
|
|
|
+ config = configs.get()
|
|
|
+ self.db = MySQLManager(config['storage']['agent_state']['mysql'])
|
|
|
+ self.table = config['storage']['agent_state']['table']
|
|
|
+
|
|
|
+ def get_state(self, staff_id: str, user_id: str) -> Tuple[DialogueState, DialogueState]:
|
|
|
+ query = f"SELECT current_state, previous_state FROM {self.table} WHERE staff_id=%s AND user_id=%s"
|
|
|
+ data = self.db.select(query, pymysql.cursors.DictCursor, (staff_id, user_id))
|
|
|
+ if not data:
|
|
|
+ logging.warning(f"staff[{staff_id}], user[{user_id}]: agent state not found")
|
|
|
+ state = DialogueState.CHITCHAT
|
|
|
+ previous_state = DialogueState.INITIALIZED
|
|
|
+ self.set_state(staff_id, user_id, state, previous_state)
|
|
|
+ else:
|
|
|
+ state = DialogueState(data[0]['current_state'])
|
|
|
+ previous_state = DialogueState(data[0]['previous_state'])
|
|
|
+ return state, previous_state
|
|
|
+
|
|
|
+ def set_state(self, staff_id: str, user_id: str, state: DialogueState, previous_state: DialogueState):
|
|
|
+ query = f"INSERT INTO {self.table} (staff_id, user_id, current_state, previous_state)" \
|
|
|
+ f" VALUES (%s, %s, %s, %s) " \
|
|
|
+ f"ON DUPLICATE KEY UPDATE current_state=%s, previous_state=%s"
|
|
|
+ self.db.execute(query, (staff_id, user_id, state.value, previous_state.value, state.value, previous_state.value))
|
|
|
+
|
|
|
class DialogueManager:
|
|
|
- def __init__(self, staff_id: str, user_id: str, user_manager: UserManager):
|
|
|
+ def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache):
|
|
|
config = configs.get()
|
|
|
|
|
|
self.staff_id = staff_id
|
|
|
self.user_id = user_id
|
|
|
self.user_manager = user_manager
|
|
|
+ self.state_cache = state_cache
|
|
|
self.current_state = DialogueState.GREETING
|
|
|
- self.previous_state = None
|
|
|
+ self.previous_state = DialogueState.INITIALIZED
|
|
|
# 目前实际仅用作调试,拼装prompt时使用history_dialogue_service获取
|
|
|
self.dialogue_history = []
|
|
|
self.user_profile = self.user_manager.get_user_profile(user_id)
|
|
@@ -75,9 +108,12 @@ class DialogueManager:
|
|
|
self.history_dialogue_service = HistoryDialogueService(
|
|
|
config['storage']['history_dialogue']['api_base_url']
|
|
|
)
|
|
|
- self._reset_interaction_time()
|
|
|
+ self._recover_state()
|
|
|
+
|
|
|
+ def _recover_state(self):
|
|
|
+ self.current_state, self.previous_state = self.state_cache.get_state(self.staff_id, self.user_id)
|
|
|
|
|
|
- def _reset_interaction_time(self):
|
|
|
+ # 从数据库恢复对话状态
|
|
|
last_message = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id, 1)
|
|
|
if last_message:
|
|
|
self.last_interaction_time = last_message[0]['timestamp']
|
|
@@ -86,7 +122,7 @@ class DialogueManager:
|
|
|
self.last_interaction_time = int(time.time() * 1000) - 24 * 3600 * 1000
|
|
|
|
|
|
@staticmethod
|
|
|
- def get_current_time_context(self) -> TimeContext:
|
|
|
+ def get_current_time_context() -> TimeContext:
|
|
|
"""获取当前时间上下文"""
|
|
|
current_hour = datetime.now().hour
|
|
|
if 5 <= current_hour < 8:
|
|
@@ -170,8 +206,9 @@ class DialogueManager:
|
|
|
else:
|
|
|
self.consecutive_clarifications = 0
|
|
|
|
|
|
- # 更新状态
|
|
|
+ # 更新状态并持久化
|
|
|
self.current_state = new_state
|
|
|
+ self.state_cache.set_state(self.staff_id, self.user_id, self.current_state, self.previous_state)
|
|
|
|
|
|
# 更新最后交互时间
|
|
|
if message_text:
|
|
@@ -352,7 +389,7 @@ class DialogueManager:
|
|
|
context = {
|
|
|
"user_profile": self.user_profile,
|
|
|
"current_state": self.current_state.name,
|
|
|
- "previous_state": self.previous_state.name if self.previous_state else None,
|
|
|
+ "previous_state": self.previous_state.name,
|
|
|
"current_time_period": time_context.description,
|
|
|
# "dialogue_history": self.dialogue_history[-10:],
|
|
|
"last_interaction_interval": self._get_hours_since_last_interaction(2),
|