浏览代码

Update dialogue_manager: fetch last push time from database

StrayWarrior 2 周之前
父节点
当前提交
c7b99edc6d
共有 3 个文件被更改,包括 38 次插入15 次删除
  1. 1 1
      pqai_agent/agent_service.py
  2. 33 12
      pqai_agent/dialogue_manager.py
  3. 4 2
      pqai_agent/push_service.py

+ 1 - 1
pqai_agent/agent_service.py

@@ -148,7 +148,7 @@ class AgentService:
         agent_key = 'agent_{}_{}'.format(staff_id, user_id)
         if agent_key not in self.agent_registry:
             self.agent_registry[agent_key] = DialogueManager(
-                staff_id, user_id, self.user_manager, self.agent_state_cache)
+                staff_id, user_id, self.user_manager, self.agent_state_cache, self.AgentDBSession)
         agent = self.agent_registry[agent_key]
         agent.refresh_profile()
         return agent

+ 33 - 12
pqai_agent/dialogue_manager.py

@@ -11,8 +11,10 @@ import textwrap
 import pymysql.cursors
 
 import cozepy
+from sqlalchemy.orm import sessionmaker, Session
 
 from pqai_agent import configs
+from pqai_agent.data_models.agent_push_record import AgentPushRecord
 from pqai_agent.logging_service import logger
 from pqai_agent.database import MySQLManager
 from pqai_agent import chat_service, prompt_templates
@@ -99,7 +101,8 @@ class DialogueStateCache:
                       .format(staff_id, user_id, state, previous_state, rows))
 
 class DialogueManager:
-    def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache):
+    def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache,
+                 AgentDBSession: sessionmaker[Session]):
         config = configs.get()
 
         self.staff_id = staff_id
@@ -113,7 +116,8 @@ class DialogueManager:
         self.user_profile = self.user_manager.get_user_profile(user_id)
         self.staff_profile = self.user_manager.get_staff_profile(staff_id)
         # FIXME: 交互时间和对话记录都涉及到回滚
-        self.last_interaction_time = 0
+        self.last_interaction_time_ms = 0
+        self.last_active_interaction_time_sec = 0
         self.human_intervention_triggered = False
         self.vector_memory = DummyVectorMemoryManager(user_id)
         self.message_aggregation_sec = config.get('agent_behavior', {}).get('message_aggregation_sec', 5)
@@ -121,6 +125,7 @@ class DialogueManager:
         self.history_dialogue_service = HistoryDialogueService(
             config['storage']['history_dialogue']['api_base_url']
         )
+        self.AgentDBSession = AgentDBSession
         self._recover_state()
         # 由于本地状态管理过于复杂,引入事务机制做状态回滚
         self._uncommited_state_change = []
@@ -159,7 +164,7 @@ class DialogueManager:
         self.dialogue_history = self.history_dialogue_service.get_dialogue_history(
             self.staff_id, self.user_id, minutes_to_get)
         if self.dialogue_history:
-            self.last_interaction_time = self.dialogue_history[-1]['timestamp']
+            self.last_interaction_time_ms = self.dialogue_history[-1]['timestamp']
             if self.current_state == DialogueState.MESSAGE_AGGREGATING:
                 # 需要恢复未处理对话,找到dialogue_history中最后未处理的user消息
                 for entry in reversed(self.dialogue_history):
@@ -168,17 +173,25 @@ class DialogueManager:
                         break
         else:
             # 默认设置
-            self.last_interaction_time = int(time.time() * 1000) - minutes_to_get * 60 * 1000
-        time_for_read = datetime.fromtimestamp(self.last_interaction_time / 1000).strftime("%Y-%m-%d %H:%M:%S")
-        logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
+            self.last_interaction_time_ms = int(time.time() * 1000) - minutes_to_get * 60 * 1000
+        with self.AgentDBSession() as session:
+            # 读取数据库中的最后一次交互时间
+            query = session.query(AgentPushRecord).filter(
+                AgentPushRecord.staff_id == self.staff_id,
+                AgentPushRecord.user_id == self.user_id
+            ).order_by(AgentPushRecord.timestamp.desc()).first()
+            if query:
+                self.last_active_interaction_time_sec = query.timestamp
+        fmt_time = datetime.fromtimestamp(self.last_interaction_time_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
+        logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {fmt_time}")
 
     def update_interaction_time(self, timestamp_ms: int):
         self._uncommited_state_change.append(DialogueStateChange(
             DialogueStateChangeType.INTERACTION_TIME,
-            self.last_interaction_time,
+            self.last_interaction_time_ms,
             timestamp_ms
         ))
-        self.last_interaction_time = timestamp_ms
+        self.last_interaction_time_ms = timestamp_ms
 
     def append_dialogue_history(self, message: Dict):
         self._uncommited_state_change.append(DialogueStateChange(
@@ -202,7 +215,7 @@ class DialogueManager:
             if entry.event_type == DialogueStateChangeType.STATE:
                 self.current_state, self.previous_state = entry.old
             elif entry.event_type == DialogueStateChangeType.INTERACTION_TIME:
-                self.last_interaction_time = entry.old
+                self.last_interaction_time_ms = entry.old
             elif entry.event_type == DialogueStateChangeType.DIALOGUE_HISTORY:
                 self.dialogue_history.pop()
             else:
@@ -255,7 +268,7 @@ class DialogueManager:
         if self.current_state == DialogueState.MESSAGE_AGGREGATING:
             # 收到的是特殊定时触发的空消息,且在聚合中,且已经超时,继续处理
             if message.type == MessageType.AGGREGATION_TRIGGER:
-                if message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
+                if message_ts - self.last_interaction_time_ms > self.message_aggregation_sec * 1000:
                     logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: exit aggregation waiting")
                 else:
                     logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: continue aggregation waiting")
@@ -418,12 +431,16 @@ class DialogueManager:
         return llm_response
 
     def _get_hours_since_last_interaction(self, precision: int = -1):
-        time_diff = (time.time() * 1000) - self.last_interaction_time
+        time_diff = (time.time() * 1000) - self.last_interaction_time_ms
         hours_passed = time_diff / 1000 / 3600
         if precision >= 0:
             return round(hours_passed, precision)
         return hours_passed
 
+    def update_last_active_interaction_time(self, timestamp_sec: int):
+        # 只需更新本地时间,重启时可从数据库恢复
+        self.last_active_interaction_time_sec = timestamp_sec
+
     def should_initiate_conversation(self) -> bool:
         """判断是否应该主动发起对话"""
         # 如果处于人工介入状态,不应主动发起对话
@@ -446,7 +463,11 @@ class DialogueManager:
             "high": 12
         }
 
-        threshold = thresholds.get(interaction_frequency, 12)
+        threshold = thresholds.get(interaction_frequency, 24)
+        #FIXME 05-21 临时策略,两次主动发起至少48小时
+        if time.time() - self.last_active_interaction_time_sec < 2 * 24 * 3600:
+            logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: last active interaction time too short")
+            return False
 
         if hours_passed < threshold:
             return False

+ 4 - 2
pqai_agent/push_service.py

@@ -153,14 +153,16 @@ class PushTaskWorkerPool:
                     recent_dialogue, content, enable_random=True)
             response = agent.generate_response(content)
             if response:
+                current_ts = int(time.time())
                 with self.agent_service.AgentDBSession() as session:
                     msg_list = [{'type': MessageType.TEXT.value, 'content': response}]
                     record = AgentPushRecord(staff_id=staff_id, user_id=user_id,
                                              content=json.dumps(msg_list, ensure_ascii=False),
-                                             timestamp=int(datetime.now().timestamp()))
+                                             timestamp=current_ts)
                     session.add(record)
                     session.commit()
                 self.agent_service.send_response(staff_id, user_id, response, message_type, skip_check=True)
+                agent.update_last_active_interaction_time(current_ts)
             else:
                 logger.debug(f"staff[{staff_id}], user[{user_id}]: generate empty response")
             self.consumer.ack(msg)
@@ -178,7 +180,7 @@ class PushTaskWorkerPool:
             message_to_user = push_agent.generate_message(
                 context=main_agent.get_prompt_context(None),
                 dialogue_history=self.agent_service.history_dialogue_db.get_dialogue_history_backward(
-                    staff_id, user_id, main_agent.last_interaction_time, limit=100
+                    staff_id, user_id, main_agent.last_interaction_time_ms, limit=100
                 )
             )
             if message_to_user: