#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
"""Dialogue state tracking, state persistence and chat-request assembly."""
from enum import Enum, auto
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime
import time
import logging

import pymysql.cursors
import cozepy

import configs
from database import MySQLManager
from history_dialogue_service import HistoryDialogueService
from chat_service import ChatServiceType
from message import MessageType, Message
# from vector_memory_manager import VectorMemoryManager
from structured_memory_manager import StructuredMemoryManager
from user_manager import UserManager
from prompt_templates import *

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
  23. class DummyVectorMemoryManager:
  24. def __init__(self, user_id):
  25. pass
  26. def add_to_memory(self, conversation):
  27. pass
  28. def retrieve_relevant_memories(self, query, k=3):
  29. return []
  30. class DialogueState(int, Enum):
  31. INITIALIZED = 0
  32. GREETING = 1 # 问候状态
  33. CHITCHAT = 2 # 闲聊状态
  34. CLARIFICATION = 3 # 澄清状态
  35. FAREWELL = 4 # 告别状态
  36. HUMAN_INTERVENTION = 5 # 人工介入状态
  37. MESSAGE_AGGREGATING = 6 # 等待消息状态
  38. class TimeContext(Enum):
  39. EARLY_MORNING = "清晨" # 清晨 (5:00-7:59)
  40. MORNING = "上午" # 上午 (8:00-11:59)
  41. NOON = "中午" # 中午 (12:00-13:59)
  42. AFTERNOON = "下午" # 下午 (14:00-17:59)
  43. EVENING = "晚上" # 晚上 (18:00-21:59)
  44. NIGHT = "深夜" # 夜晚 (22:00-4:59)
  45. def __init__(self, description):
  46. self.description = description
  47. class DialogueStateCache:
  48. def __init__(self):
  49. config = configs.get()
  50. self.db = MySQLManager(config['storage']['agent_state']['mysql'])
  51. self.table = config['storage']['agent_state']['table']
  52. def get_state(self, staff_id: str, user_id: str) -> Tuple[DialogueState, DialogueState]:
  53. query = f"SELECT current_state, previous_state FROM {self.table} WHERE staff_id=%s AND user_id=%s"
  54. data = self.db.select(query, pymysql.cursors.DictCursor, (staff_id, user_id))
  55. if not data:
  56. logging.warning(f"staff[{staff_id}], user[{user_id}]: agent state not found")
  57. state = DialogueState.CHITCHAT
  58. previous_state = DialogueState.INITIALIZED
  59. self.set_state(staff_id, user_id, state, previous_state)
  60. else:
  61. state = DialogueState(data[0]['current_state'])
  62. previous_state = DialogueState(data[0]['previous_state'])
  63. return state, previous_state
  64. def set_state(self, staff_id: str, user_id: str, state: DialogueState, previous_state: DialogueState):
  65. query = f"INSERT INTO {self.table} (staff_id, user_id, current_state, previous_state)" \
  66. f" VALUES (%s, %s, %s, %s) " \
  67. f"ON DUPLICATE KEY UPDATE current_state=%s, previous_state=%s"
  68. rows = self.db.execute(query, (staff_id, user_id, state.value, previous_state.value, state.value, previous_state.value))
  69. logging.debug("staff[{}], user[{}]: set state: {}, previous state: {}, rows affected: {}"
  70. .format(staff_id, user_id, state, previous_state, rows))
  71. class DialogueManager:
  72. def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache):
  73. config = configs.get()
  74. self.staff_id = staff_id
  75. self.user_id = user_id
  76. self.user_manager = user_manager
  77. self.state_cache = state_cache
  78. self.current_state = DialogueState.GREETING
  79. self.previous_state = DialogueState.INITIALIZED
  80. # 目前实际仅用作调试,拼装prompt时使用history_dialogue_service获取
  81. self.dialogue_history = []
  82. self.user_profile = self.user_manager.get_user_profile(user_id)
  83. self.staff_profile = self.user_manager.get_staff_profile(staff_id)
  84. self.last_interaction_time = 0
  85. self.consecutive_clarifications = 0
  86. self.complex_request_counter = 0
  87. self.human_intervention_triggered = False
  88. self.vector_memory = DummyVectorMemoryManager(user_id)
  89. self.message_aggregation_sec = config.get('message_aggregation_sec', 5)
  90. self.unprocessed_messages = []
  91. self.history_dialogue_service = HistoryDialogueService(
  92. config['storage']['history_dialogue']['api_base_url']
  93. )
  94. self._recover_state()
  95. def _recover_state(self):
  96. self.current_state, self.previous_state = self.state_cache.get_state(self.staff_id, self.user_id)
  97. # 从数据库恢复对话状态
  98. last_message = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id, 1)
  99. if last_message:
  100. self.last_interaction_time = last_message[0]['timestamp']
  101. else:
  102. # 默认设置为24小时前
  103. self.last_interaction_time = int(time.time() * 1000) - 24 * 3600 * 1000
  104. time_for_read = datetime.fromtimestamp(self.last_interaction_time / 1000).strftime("%Y-%m-%d %H:%M:%S")
  105. logging.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
  106. def persist_state(self):
  107. """持久化对话状态"""
  108. self.state_cache.set_state(self.staff_id, self.user_id, self.current_state, self.previous_state)
  109. @staticmethod
  110. def get_current_time_context() -> TimeContext:
  111. """获取当前时间上下文"""
  112. current_hour = datetime.now().hour
  113. if 5 <= current_hour < 8:
  114. return TimeContext.EARLY_MORNING
  115. elif 8 <= current_hour < 12:
  116. return TimeContext.MORNING
  117. elif 12 <= current_hour < 14:
  118. return TimeContext.NOON
  119. elif 14 <= current_hour < 18:
  120. return TimeContext.AFTERNOON
  121. elif 18 <= current_hour < 22:
  122. return TimeContext.EVENING
  123. else:
  124. return TimeContext.NIGHT
  125. def update_state(self, message: Message) -> Tuple[bool, Optional[str]]:
  126. """根据用户消息更新对话状态,并返回是否需要发起回复 及下一条需处理的用户消息"""
  127. message_text = message.content
  128. message_ts = message.sendTime
  129. # 如果当前已经是人工介入状态,保持该状态
  130. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  131. # 记录对话历史,但不改变状态
  132. self.dialogue_history.append({
  133. "role": "user",
  134. "content": message_text,
  135. "timestamp": int(time.time() * 1000),
  136. "state": self.current_state.name
  137. })
  138. return False, message_text
  139. # 检查是否处于消息聚合状态
  140. if self.current_state == DialogueState.MESSAGE_AGGREGATING:
  141. # 收到的是特殊定时触发的空消息,且在聚合中,且已经超时,恢复之前状态,继续处理
  142. if message.type == MessageType.AGGREGATION_TRIGGER \
  143. and message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
  144. logging.debug("user_id: {}, last interaction time: {}".format(
  145. self.user_id, datetime.fromtimestamp(self.last_interaction_time / 1000)))
  146. self.current_state = self.previous_state
  147. else:
  148. # 非空消息,更新最后交互时间,保持消息聚合状态
  149. if message_text:
  150. self.unprocessed_messages.append(message_text)
  151. self.last_interaction_time = message_ts
  152. return False, message_text
  153. else:
  154. if message.type == MessageType.AGGREGATION_TRIGGER:
  155. # 未在聚合状态中,收到的聚合触发消息为过时消息,不应当处理
  156. return False, None
  157. if message.type != MessageType.AGGREGATION_TRIGGER and self.message_aggregation_sec > 0:
  158. # 收到有内容的用户消息,切换到消息聚合状态
  159. self.previous_state = self.current_state
  160. self.current_state = DialogueState.MESSAGE_AGGREGATING
  161. self.unprocessed_messages.append(message_text)
  162. # 更新最后交互时间
  163. if message_text:
  164. self.last_interaction_time = message_ts
  165. self.persist_state()
  166. return False, message_text
  167. # 保存前一个状态
  168. self.previous_state = self.current_state
  169. # 检查是否长时间未交互(超过3小时)
  170. if self._get_hours_since_last_interaction() > 3:
  171. self.current_state = DialogueState.GREETING
  172. self.dialogue_history = [] # 重置对话历史
  173. self.consecutive_clarifications = 0 # 重置澄清计数
  174. self.complex_request_counter = 0 # 重置复杂请求计数
  175. # 获得未处理的聚合消息,并清空未处理队列
  176. if message_text:
  177. self.unprocessed_messages.append(message_text)
  178. if self.unprocessed_messages:
  179. message_text = '\n'.join(self.unprocessed_messages)
  180. self.unprocessed_messages.clear()
  181. # 根据消息内容和当前状态确定新状态
  182. new_state = self._determine_state_from_message(message_text)
  183. # 处理连续澄清的情况
  184. if new_state == DialogueState.CLARIFICATION:
  185. self.consecutive_clarifications += 1
  186. # FIXME(zhoutian): 规则过于简单
  187. if self.consecutive_clarifications >= 10000:
  188. new_state = DialogueState.HUMAN_INTERVENTION
  189. # self._trigger_human_intervention("连续多次澄清请求")
  190. else:
  191. self.consecutive_clarifications = 0
  192. # 更新状态并持久化
  193. self.current_state = new_state
  194. self.persist_state()
  195. # 更新最后交互时间
  196. if message_text:
  197. self.last_interaction_time = message_ts
  198. # 记录对话历史
  199. if message_text:
  200. self.dialogue_history.append({
  201. "role": "user",
  202. "content": message_text,
  203. "timestamp": message_ts,
  204. "state": self.current_state.name
  205. })
  206. return True, message_text
  207. def _determine_state_from_message(self, message_text: Optional[str]) -> DialogueState:
  208. """根据消息内容确定对话状态"""
  209. if not message_text:
  210. return self.current_state
  211. # 简单的规则-关键词匹配
  212. message_lower = message_text.lower()
  213. # 判断是否是复杂请求
  214. # FIXME(zhoutian): 规则过于简单
  215. # complex_request_keywords = ["帮我", "怎么办", "我需要", "麻烦你", "请帮助", "急", "紧急"]
  216. # if any(keyword in message_lower for keyword in complex_request_keywords):
  217. # self.complex_request_counter += 1
  218. #
  219. # # 如果检测到困难请求且计数达到阈值,触发人工介入
  220. # if self.complex_request_counter >= 1:
  221. # # self._trigger_human_intervention("检测到复杂请求")
  222. # return DialogueState.HUMAN_INTERVENTION
  223. # else:
  224. # # 如果不是复杂请求,重置计数器
  225. # self.complex_request_counter = 0
  226. # 问候检测
  227. greeting_keywords = ["你好", "早上好", "中午好", "晚上好", "嗨", "在吗"]
  228. if any(keyword in message_lower for keyword in greeting_keywords):
  229. return DialogueState.GREETING
  230. # 告别检测
  231. farewell_keywords = ["再见", "拜拜", "晚安", "明天见", "回头见"]
  232. if any(keyword in message_lower for keyword in farewell_keywords):
  233. return DialogueState.FAREWELL
  234. # 澄清请求
  235. clarification_keywords = ["没明白", "不明白", "没听懂", "不懂", "什么意思", "再说一遍"]
  236. if any(keyword in message_lower for keyword in clarification_keywords):
  237. return DialogueState.CLARIFICATION
  238. # 默认为闲聊状态
  239. return DialogueState.CHITCHAT
  240. def _trigger_human_intervention(self, reason: str) -> None:
  241. """触发人工介入"""
  242. if not self.human_intervention_triggered:
  243. self.human_intervention_triggered = True
  244. # 记录人工介入事件
  245. event = {
  246. "timestamp": int(time.time() * 1000),
  247. "reason": reason,
  248. "dialogue_context": self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id, 5)
  249. }
  250. # 更新用户资料中的人工介入历史
  251. if "human_intervention_history" not in self.user_profile:
  252. self.user_profile["human_intervention_history"] = []
  253. self.user_profile["human_intervention_history"].append(event)
  254. self.user_manager.save_user_profile(self.user_id, self.user_profile)
  255. # 发送告警
  256. self._send_human_intervention_alert(reason)
  257. def _send_human_intervention_alert(self, reason: str) -> None:
  258. alert_message = f"""
  259. 人工介入告警
  260. 用户ID: {self.user_id}
  261. 用户昵称: {self.user_profile.get("nickname", "未知")}
  262. 时间: {int(time.time() * 1000)}
  263. 原因: {reason}
  264. 最近对话:
  265. """
  266. # 添加最近的对话记录
  267. recent_dialogues = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id, 5)
  268. for dialogue in recent_dialogues:
  269. alert_message += f"\n{dialogue['role']}: {dialogue['content']}"
  270. # TODO(zhoutian): 实现发送告警的具体逻辑
  271. logger.warning(alert_message)
  272. def resume_from_human_intervention(self) -> None:
  273. """从人工介入状态恢复"""
  274. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  275. self.current_state = DialogueState.GREETING
  276. self.human_intervention_triggered = False
  277. self.consecutive_clarifications = 0
  278. self.complex_request_counter = 0
  279. # 记录恢复事件
  280. self.dialogue_history.append({
  281. "role": "system",
  282. "content": "已从人工介入状态恢复到自动对话",
  283. "timestamp": int(time.time() * 1000),
  284. "state": self.current_state.name
  285. })
  286. def generate_response(self, llm_response: str) -> Optional[str]:
  287. """根据当前状态处理LLM响应,如果处于人工介入状态则返回None"""
  288. # 如果处于人工介入状态,不生成回复
  289. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  290. return None
  291. # 记录响应到对话历史
  292. current_ts = int(time.time() * 1000)
  293. self.dialogue_history.append({
  294. "role": "assistant",
  295. "content": llm_response,
  296. "timestamp": current_ts,
  297. "state": self.current_state.name
  298. })
  299. self.last_interaction_time = current_ts
  300. return llm_response
  301. def _get_hours_since_last_interaction(self, precision: int = -1):
  302. time_diff = (time.time() * 1000) - self.last_interaction_time
  303. hours_passed = time_diff / 1000 / 3600
  304. if precision >= 0:
  305. return round(hours_passed, precision)
  306. return hours_passed
  307. def should_initiate_conversation(self) -> bool:
  308. """判断是否应该主动发起对话"""
  309. # 如果处于人工介入状态,不应主动发起对话
  310. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  311. return False
  312. hours_passed = self._get_hours_since_last_interaction()
  313. # 获取当前时间上下文
  314. time_context = self.get_current_time_context()
  315. # 根据用户交互频率偏好设置不同的阈值
  316. interaction_frequency = self.user_profile.get("interaction_frequency", "medium")
  317. # 设置不同偏好的交互时间阈值(小时)
  318. thresholds = {
  319. "low": 24, # 低频率:一天一次
  320. "medium": 12, # 中频率:半天一次
  321. "high": 6 # 高频率:大约6小时一次
  322. }
  323. threshold = thresholds.get(interaction_frequency, 12)
  324. if hours_passed < threshold:
  325. return False
  326. # 根据时间上下文决定主动交互的状态
  327. if time_context in [TimeContext.EARLY_MORNING, TimeContext.MORNING,
  328. TimeContext.NOON, TimeContext.AFTERNOON,
  329. TimeContext.EVENING]:
  330. self.previous_state = self.current_state
  331. self.current_state = DialogueState.GREETING
  332. self.persist_state()
  333. return True
  334. return False
  335. def is_in_human_intervention(self) -> bool:
  336. """检查是否处于人工介入状态"""
  337. return self.current_state == DialogueState.HUMAN_INTERVENTION
  338. def get_prompt_context(self, user_message) -> Dict:
  339. # 获取当前时间上下文
  340. time_context = self.get_current_time_context()
  341. # 刷新用户画像
  342. self.user_profile = self.user_manager.get_user_profile(self.user_id)
  343. # 刷新员工画像(不一定需要)
  344. self.staff_profile = self.user_manager.get_staff_profile(self.staff_id)
  345. context = {
  346. "user_profile": self.user_profile,
  347. "current_state": self.current_state.name,
  348. "previous_state": self.previous_state.name,
  349. "current_time_period": time_context.description,
  350. # "dialogue_history": self.dialogue_history[-10:],
  351. "last_interaction_interval": self._get_hours_since_last_interaction(2),
  352. "if_first_interaction": False,
  353. "if_active_greeting": False if user_message else True,
  354. **self.user_profile,
  355. **self.staff_profile
  356. }
  357. # 获取长期记忆
  358. relevant_memories = self.vector_memory.retrieve_relevant_memories(user_message)
  359. context["long_term_memory"] = {
  360. "relevant_conversations": relevant_memories
  361. }
  362. return context
  363. def _select_prompt(self, state):
  364. state_to_prompt_map = {
  365. DialogueState.GREETING: GENERAL_GREETING_PROMPT,
  366. DialogueState.CHITCHAT: GENERAL_GREETING_PROMPT,
  367. DialogueState.FAREWELL: GENERAL_GREETING_PROMPT
  368. }
  369. return state_to_prompt_map[state]
  370. def _select_coze_bot(self, state):
  371. state_to_bot_map = {
  372. DialogueState.GREETING: '7486112546798780425',
  373. DialogueState.CHITCHAT: '7491300566573301770',
  374. DialogueState.FAREWELL: '7491300566573301770'
  375. }
  376. return state_to_bot_map[state]
  377. def _create_system_message(self, prompt_context):
  378. prompt_template = self._select_prompt(self.current_state)
  379. prompt = prompt_template.format(**prompt_context)
  380. return {'role': 'system', 'content': prompt}
  381. def build_chat_configuration(
  382. self,
  383. user_message: Optional[str] = None,
  384. chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
  385. ) -> Dict:
  386. """
  387. 参数:
  388. user_message: 当前用户消息,如果是主动交互则为None
  389. 返回:
  390. 消息列表
  391. """
  392. dialogue_history = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id)
  393. logging.debug("staff[{}], user[{}], dialogue_history: {}".format(
  394. self.staff_id, self.user_id, dialogue_history
  395. ))
  396. messages = []
  397. config = {}
  398. prompt_context = self.get_prompt_context(user_message)
  399. if chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
  400. system_message = self._create_system_message(prompt_context)
  401. messages.append(system_message)
  402. for entry in dialogue_history:
  403. role = entry['role']
  404. messages.append({
  405. "role": role,
  406. "content": entry["content"]
  407. })
  408. elif chat_service_type == ChatServiceType.COZE_CHAT:
  409. for entry in dialogue_history:
  410. if not entry['content']:
  411. logging.warning("staff[{}], user[{}], role[{}]: empty content in dialogue history".format(
  412. self.staff_id, self.user_id, entry['role']
  413. ))
  414. continue
  415. role = entry['role']
  416. if role == 'user':
  417. messages.append(cozepy.Message.build_user_question_text(entry["content"]))
  418. elif role == 'assistant':
  419. messages.append(cozepy.Message.build_assistant_answer(entry['content']))
  420. custom_variables = {}
  421. for k, v in prompt_context.items():
  422. custom_variables[k] = str(v)
  423. custom_variables.pop('user_profile', None)
  424. config['custom_variables'] = custom_variables
  425. config['bot_id'] = self._select_coze_bot(self.current_state)
  426. #FIXME(zhoutian): 这种方法并不可靠,需要结合状态来判断
  427. if self.current_state == DialogueState.GREETING and not messages:
  428. messages.append(cozepy.Message.build_user_question_text('请开始对话'))
  429. #FIXME(zhoutian): 临时报警
  430. if user_message and not messages:
  431. logging.error(f"staff[{self.staff_id}], user[{self.user_id}]: inconsistency in messages")
  432. config['messages'] = messages
  433. return config
  434. if __name__ == '__main__':
  435. state_cache = DialogueStateCache()
  436. state_cache.set_state('1688854492669990', '7881302581935903', DialogueState.CHITCHAT, DialogueState.GREETING)