瀏覽代碼

Update dialogue_manager: add state cache

StrayWarrior 2 周之前
父節點
當前提交
52ec99c3fd
共有 2 個文件被更改,包括 53 次插入16 次删除
  1. 2 2
      database.py
  2. 51 14
      dialogue_manager.py

+ 2 - 2
database.py

@@ -11,13 +11,13 @@ class MySQLManager:
     def __init__(self, config):
         self.config = config
 
-    def select(self, sql, cursor_type=None):
+    def select(self, sql, cursor_type=None, args=None):
         """
         sql: SQL to execute, string
         """
         conn = pymysql.connect(**self.config)
         cursor = conn.cursor(cursor_type)
-        cursor.execute(sql)
+        cursor.execute(sql, args)
         data = cursor.fetchall()
         # do not handle exception
         cursor.close()

+ 51 - 14
dialogue_manager.py

@@ -7,8 +7,13 @@ from typing import Dict, List, Optional, Tuple, Any
 from datetime import datetime
 import time
 import logging
+
+import pymysql.cursors
+
 import configs
 import cozepy
+
+from database import MySQLManager
 from history_dialogue_service import HistoryDialogueService
 
 from chat_service import ChatServiceType
@@ -33,13 +38,14 @@ class DummyVectorMemoryManager:
         return []
 
 
-class DialogueState(Enum):
-    GREETING = auto()              # 问候状态
-    CHITCHAT = auto()              # 闲聊状态
-    CLARIFICATION = auto()         # 澄清状态
-    FAREWELL = auto()              # 告别状态
-    HUMAN_INTERVENTION = auto()    # 人工介入状态
-    MESSAGE_AGGREGATING = auto()   # 等待消息状态
+class DialogueState(int, Enum):
+    INITIALIZED = 0
+    GREETING = 1              # 问候状态
+    CHITCHAT = 2              # 闲聊状态
+    CLARIFICATION = 3         # 澄清状态
+    FAREWELL = 4              # 告别状态
+    HUMAN_INTERVENTION = 5    # 人工介入状态
+    MESSAGE_AGGREGATING = 6   # 等待消息状态
 
 
 class TimeContext(Enum):
@@ -53,15 +59,42 @@ class TimeContext(Enum):
     def __init__(self, description):
         self.description = description
 
+
+class DialogueStateCache:
+    def __init__(self):
+        config = configs.get()
+        self.db = MySQLManager(config['storage']['agent_state']['mysql'])
+        self.table = config['storage']['agent_state']['table']
+
+    def get_state(self, staff_id: str, user_id: str) -> Tuple[DialogueState, DialogueState]:
+        query = f"SELECT current_state, previous_state FROM {self.table} WHERE staff_id=%s AND user_id=%s"
+        data = self.db.select(query, pymysql.cursors.DictCursor, (staff_id, user_id))
+        if not data:
+            logging.warning(f"staff[{staff_id}], user[{user_id}]: agent state not found")
+            state = DialogueState.CHITCHAT
+            previous_state = DialogueState.INITIALIZED
+            self.set_state(staff_id, user_id, state, previous_state)
+        else:
+            state = DialogueState(data[0]['current_state'])
+            previous_state = DialogueState(data[0]['previous_state'])
+        return state, previous_state
+
+    def set_state(self, staff_id: str, user_id: str, state: DialogueState, previous_state: DialogueState):
+        query = f"INSERT INTO {self.table} (staff_id, user_id, current_state, previous_state)" \
+                f" VALUES (%s, %s, %s, %s) " \
+                f"ON DUPLICATE KEY UPDATE current_state=%s, previous_state=%s"
+        self.db.execute(query, (staff_id, user_id, state.value, previous_state.value, state.value, previous_state.value))
+
 class DialogueManager:
-    def __init__(self, staff_id: str, user_id: str, user_manager: UserManager):
+    def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache):
         config = configs.get()
 
         self.staff_id = staff_id
         self.user_id = user_id
         self.user_manager = user_manager
+        self.state_cache = state_cache
         self.current_state = DialogueState.GREETING
-        self.previous_state = None
+        self.previous_state = DialogueState.INITIALIZED
         # 目前实际仅用作调试,拼装prompt时使用history_dialogue_service获取
         self.dialogue_history = []
         self.user_profile = self.user_manager.get_user_profile(user_id)
@@ -75,9 +108,12 @@ class DialogueManager:
         self.history_dialogue_service = HistoryDialogueService(
             config['storage']['history_dialogue']['api_base_url']
         )
-        self._reset_interaction_time()
+        self._recover_state()
+
+    def _recover_state(self):
+        self.current_state, self.previous_state = self.state_cache.get_state(self.staff_id, self.user_id)
 
-    def _reset_interaction_time(self):
+        # 从数据库恢复对话状态
         last_message = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id, 1)
         if last_message:
             self.last_interaction_time = last_message[0]['timestamp']
@@ -86,7 +122,7 @@ class DialogueManager:
             self.last_interaction_time = int(time.time() * 1000) - 24 * 3600 * 1000
 
     @staticmethod
-    def get_current_time_context(self) -> TimeContext:
+    def get_current_time_context() -> TimeContext:
         """获取当前时间上下文"""
         current_hour = datetime.now().hour
         if 5 <= current_hour < 8:
@@ -170,8 +206,9 @@ class DialogueManager:
         else:
             self.consecutive_clarifications = 0
 
-        # 更新状态
+        # 更新状态并持久化
         self.current_state = new_state
+        self.state_cache.set_state(self.staff_id, self.user_id, self.current_state, self.previous_state)
 
         # 更新最后交互时间
         if message_text:
@@ -352,7 +389,7 @@ class DialogueManager:
         context = {
             "user_profile": self.user_profile,
             "current_state": self.current_state.name,
-            "previous_state": self.previous_state.name if self.previous_state else None,
+            "previous_state": self.previous_state.name,
             "current_time_period": time_context.description,
             # "dialogue_history": self.dialogue_history[-10:],
             "last_interaction_interval": self._get_hours_since_last_interaction(2),