dialogue_manager.py 38 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import random
  5. from enum import Enum
  6. from typing import Dict, List, Optional, Tuple, Any
  7. from datetime import datetime
  8. import time
  9. import textwrap
  10. import pymysql.cursors
  11. import cozepy
  12. from sqlalchemy.orm import sessionmaker, Session
  13. from pqai_agent import configs
  14. from pqai_agent.data_models.agent_push_record import AgentPushRecord
  15. from pqai_agent.logging_service import logger
  16. from pqai_agent.database import MySQLManager
  17. from pqai_agent import chat_service, prompt_templates
  18. from pqai_agent.history_dialogue_service import HistoryDialogueService
  19. from pqai_agent.chat_service import ChatServiceType
  20. from pqai_agent.mq_message import MessageType, MqMessage
  21. from pqai_agent.toolkit.lark_alert_for_human_intervention import LarkAlertForHumanIntervention
  22. from pqai_agent.toolkit.lark_sheet_record_for_human_intervention import LarkSheetRecordForHumanIntervention
  23. from pqai_agent.user_manager import UserManager
  24. from pqai_agent.utils import prompt_utils
  25. class DummyVectorMemoryManager:
  26. def __init__(self, user_id):
  27. pass
  28. def add_to_memory(self, conversation):
  29. pass
  30. def retrieve_relevant_memories(self, query, k=3):
  31. return []
  32. class DialogueState(int, Enum):
  33. INITIALIZED = 0
  34. GREETING = 1 # 问候状态
  35. CHITCHAT = 2 # 闲聊状态
  36. CLARIFICATION = 3 # 澄清状态
  37. FAREWELL = 4 # 告别状态
  38. HUMAN_INTERVENTION = 5 # 人工介入状态
  39. MESSAGE_AGGREGATING = 6 # 等待消息状态
  40. class TimeContext(Enum):
  41. EARLY_MORNING = "清晨" # 清晨 (5:00-7:59)
  42. MORNING = "上午" # 上午 (8:00-11:59)
  43. NOON = "中午" # 中午 (12:00-13:59)
  44. AFTERNOON = "下午" # 下午 (14:00-17:59)
  45. EVENING = "晚上" # 晚上 (18:00-21:59)
  46. NIGHT = "深夜" # 夜晚 (22:00-4:59)
  47. def __init__(self, description):
  48. self.description = description
  49. class DialogueStateChangeType(int, Enum):
  50. STATE = 0
  51. INTERACTION_TIME = 1
  52. DIALOGUE_HISTORY = 2
  53. class DialogueStateChange:
  54. def __init__(self, event_type: DialogueStateChangeType,old: Any, new: Any):
  55. self.event_type = event_type
  56. self.old = old
  57. self.new = new
  58. class DialogueStateCache:
  59. def __init__(self):
  60. self.config = configs.get()
  61. self.db = MySQLManager(self.config['storage']['agent_state']['mysql'])
  62. self.table = self.config['storage']['agent_state']['table']
  63. def get_state(self, staff_id: str, user_id: str) -> Tuple[DialogueState, DialogueState]:
  64. query = f"SELECT current_state, previous_state FROM {self.table} WHERE staff_id=%s AND user_id=%s"
  65. data = self.db.select(query, pymysql.cursors.DictCursor, (staff_id, user_id))
  66. if not data:
  67. logger.warning(f"staff[{staff_id}], user[{user_id}]: agent state not found")
  68. state = DialogueState.INITIALIZED
  69. previous_state = DialogueState.INITIALIZED
  70. self.set_state(staff_id, user_id, state, previous_state)
  71. else:
  72. state = DialogueState(data[0]['current_state'])
  73. previous_state = DialogueState(data[0]['previous_state'])
  74. return state, previous_state
  75. def set_state(self, staff_id: str, user_id: str, state: DialogueState, previous_state: DialogueState):
  76. if self.config.get('debug_flags', {}).get('disable_database_write', False):
  77. return
  78. query = f"INSERT INTO {self.table} (staff_id, user_id, current_state, previous_state)" \
  79. f" VALUES (%s, %s, %s, %s) " \
  80. f"ON DUPLICATE KEY UPDATE current_state=%s, previous_state=%s"
  81. rows = self.db.execute(query, (staff_id, user_id, state.value, previous_state.value, state.value, previous_state.value))
  82. logger.debug("staff[{}], user[{}]: set state: {}, previous state: {}, rows affected: {}"
  83. .format(staff_id, user_id, state, previous_state, rows))
  84. class DialogueManager:
  85. def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache,
  86. AgentDBSession: sessionmaker[Session]):
  87. config = configs.get()
  88. self.staff_id = staff_id
  89. self.user_id = user_id
  90. self.user_manager = user_manager
  91. self.state_cache = state_cache
  92. self.current_state = DialogueState.GREETING
  93. self.previous_state = DialogueState.INITIALIZED
  94. # 目前实际仅用作调试,拼装prompt时使用history_dialogue_service获取
  95. self.dialogue_history = []
  96. self.user_profile = self.user_manager.get_user_profile(user_id)
  97. self.staff_profile = self.user_manager.get_staff_profile(staff_id)
  98. # FIXME: 交互时间和对话记录都涉及到回滚
  99. self.last_interaction_time_ms = 0
  100. self.last_active_interaction_time_sec = 0
  101. self.human_intervention_triggered = False
  102. self.vector_memory = DummyVectorMemoryManager(user_id)
  103. self.message_aggregation_sec = config.get('agent_behavior', {}).get('message_aggregation_sec', 5)
  104. self.unprocessed_messages = []
  105. self.history_dialogue_service = HistoryDialogueService(
  106. config['storage']['history_dialogue']['api_base_url']
  107. )
  108. self.AgentDBSession = AgentDBSession
  109. self._recover_state()
  110. # 由于本地状态管理过于复杂,引入事务机制做状态回滚
  111. self._uncommited_state_change = []
  112. @staticmethod
  113. def get_time_context(current_hour=None) -> TimeContext:
  114. """获取当前时间上下文"""
  115. if not current_hour:
  116. current_hour = datetime.now().hour
  117. if 5 <= current_hour < 7:
  118. return TimeContext.EARLY_MORNING
  119. elif 7 <= current_hour < 11:
  120. return TimeContext.MORNING
  121. elif 11 <= current_hour < 14:
  122. return TimeContext.NOON
  123. elif 14 <= current_hour < 18:
  124. return TimeContext.AFTERNOON
  125. elif 18 <= current_hour < 22:
  126. return TimeContext.EVENING
  127. else:
  128. return TimeContext.NIGHT
  129. def is_valid(self):
  130. if not self.staff_profile.get('name', None) and not self.staff_profile.get('agent_name', None):
  131. return False
  132. return True
  133. def refresh_profile(self):
  134. self.staff_profile = self.user_manager.get_staff_profile(self.staff_id)
  135. def _recover_state(self):
  136. self.current_state, self.previous_state = self.state_cache.get_state(self.staff_id, self.user_id)
  137. # 从数据库恢复对话状态
  138. minutes_to_get = 5 * 24 * 60
  139. self.dialogue_history = self.history_dialogue_service.get_dialogue_history(
  140. self.staff_id, self.user_id, minutes_to_get)
  141. if self.dialogue_history:
  142. self.last_interaction_time_ms = self.dialogue_history[-1]['timestamp']
  143. if self.current_state == DialogueState.MESSAGE_AGGREGATING:
  144. # 需要恢复未处理对话,找到dialogue_history中最后未处理的user消息
  145. for entry in reversed(self.dialogue_history):
  146. if entry['role'] == 'user':
  147. self.unprocessed_messages.append(entry['content'])
  148. break
  149. else:
  150. # 默认设置
  151. self.last_interaction_time_ms = int(time.time() * 1000) - minutes_to_get * 60 * 1000
  152. with self.AgentDBSession() as session:
  153. # 读取数据库中的最后一次交互时间
  154. query = session.query(AgentPushRecord).filter(
  155. AgentPushRecord.staff_id == self.staff_id,
  156. AgentPushRecord.user_id == self.user_id
  157. ).order_by(AgentPushRecord.timestamp.desc()).first()
  158. if query:
  159. self.last_active_interaction_time_sec = query.timestamp
  160. fmt_time = datetime.fromtimestamp(self.last_interaction_time_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  161. logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {fmt_time}")
  162. def update_interaction_time(self, timestamp_ms: int):
  163. self._uncommited_state_change.append(DialogueStateChange(
  164. DialogueStateChangeType.INTERACTION_TIME,
  165. self.last_interaction_time_ms,
  166. timestamp_ms
  167. ))
  168. self.last_interaction_time_ms = timestamp_ms
  169. def append_dialogue_history(self, message: Dict):
  170. self._uncommited_state_change.append(DialogueStateChange(
  171. DialogueStateChangeType.DIALOGUE_HISTORY,
  172. None,
  173. 1
  174. ))
  175. self.dialogue_history.append(message)
  176. def persist_state(self):
  177. """持久化对话状态,只有当前状态处理成功后才应该做持久化"""
  178. self.commit()
  179. config = configs.get()
  180. if config.get('debug_flags', {}).get('disable_database_write', False):
  181. return
  182. self.state_cache.set_state(self.staff_id, self.user_id, self.current_state, self.previous_state)
  183. def rollback_state(self):
  184. logger.info(f"staff[{self.staff_id}], user[{self.user_id}]: reverse state")
  185. for entry in reversed(self._uncommited_state_change):
  186. if entry.event_type == DialogueStateChangeType.STATE:
  187. self.current_state, self.previous_state = entry.old
  188. elif entry.event_type == DialogueStateChangeType.INTERACTION_TIME:
  189. self.last_interaction_time_ms = entry.old
  190. elif entry.event_type == DialogueStateChangeType.DIALOGUE_HISTORY:
  191. self.dialogue_history.pop()
  192. else:
  193. logger.error(f"unimplemented type: [{entry.event_type}]")
  194. self._uncommited_state_change.clear()
  195. def commit(self):
  196. self._uncommited_state_change.clear()
  197. def do_state_change(self, state: DialogueState):
  198. state_backup = (self.current_state, self.previous_state)
  199. if self.current_state == DialogueState.MESSAGE_AGGREGATING:
  200. # MESSAGE_AGGREGATING不能成为previous_state,仅使用state_backup做回退
  201. self.current_state = state
  202. else:
  203. self.previous_state = self.current_state
  204. self.current_state = state
  205. self._uncommited_state_change.append(DialogueStateChange(
  206. DialogueStateChangeType.STATE,
  207. state_backup,
  208. (self.current_state, self.previous_state)
  209. ))
  210. def update_state(self, message: MqMessage) -> Tuple[bool, Optional[str]]:
  211. """根据用户消息更新对话状态,并返回是否需要发起回复 及下一条需处理的用户消息"""
  212. message_text = message.content
  213. message_ts = message.sendTime
  214. # 如果当前已经是人工介入状态,根据消息类型决定保持/退出
  215. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  216. if message.type == MessageType.HUMAN_INTERVENTION_END:
  217. self.resume_from_human_intervention()
  218. # 恢复状态,但无需Agent产生回复
  219. return False, None
  220. else:
  221. self.append_dialogue_history({
  222. "role": "user",
  223. "content": message_text,
  224. "timestamp": int(time.time() * 1000),
  225. "state": self.current_state.name
  226. })
  227. return False, message_text
  228. if message.type == MessageType.ENTER_HUMAN_INTERVENTION:
  229. logger.info(f"staff[{self.staff_id}], user[{self.user_id}]: human intervention triggered")
  230. self.do_state_change(DialogueState.HUMAN_INTERVENTION)
  231. return False, None
  232. # 检查是否处于消息聚合状态
  233. if self.current_state == DialogueState.MESSAGE_AGGREGATING:
  234. # 收到的是特殊定时触发的空消息,且在聚合中,且已经超时,继续处理
  235. if message.type == MessageType.AGGREGATION_TRIGGER:
  236. if message_ts - self.last_interaction_time_ms > self.message_aggregation_sec * 1000:
  237. logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: exit aggregation waiting")
  238. else:
  239. logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: continue aggregation waiting")
  240. return False, message_text
  241. else:
  242. # 非空消息,更新最后交互时间,保持消息聚合状态
  243. if message_text:
  244. self.unprocessed_messages.append(message_text)
  245. self.update_interaction_time(message_ts)
  246. return False, message_text
  247. else:
  248. if message.type == MessageType.AGGREGATION_TRIGGER:
  249. # 未在聚合状态中,收到的聚合触发消息为过时消息,不应当处理
  250. logger.warning(f"staff[{self.staff_id}], user[{self.user_id}]: received {message.type} in state {self.current_state}")
  251. return False, None
  252. if message.type == MessageType.HUMAN_INTERVENTION_END:
  253. # 未在人工介入状态中,收到的人工介入结束事件为过时消息,不应当处理
  254. logger.warning(f"staff[{self.staff_id}], user[{self.user_id}]: received {message.type} in state {self.current_state}")
  255. return False, None
  256. if message.type != MessageType.AGGREGATION_TRIGGER and self.message_aggregation_sec > 0:
  257. # 收到有内容的用户消息,切换到消息聚合状态
  258. self.do_state_change(DialogueState.MESSAGE_AGGREGATING)
  259. self.unprocessed_messages.append(message_text)
  260. # 更新最后交互时间
  261. if message_text:
  262. self.update_interaction_time(message_ts)
  263. return False, message_text
  264. # 获得未处理的聚合消息,并清空未处理队列
  265. if message_text:
  266. self.unprocessed_messages.append(message_text)
  267. if self.unprocessed_messages:
  268. message_text = '\n'.join(self.unprocessed_messages)
  269. self.unprocessed_messages.clear()
  270. # 实际上这里message_text并不会被最终送入LLM,只是用来做状态判断
  271. # 根据消息内容和当前状态确定新状态
  272. new_state = self._determine_state_from_message(message_text)
  273. # 更新状态
  274. self.do_state_change(new_state)
  275. if message_text:
  276. self.update_interaction_time(message_ts)
  277. self.append_dialogue_history({
  278. "role": "user",
  279. "content": message_text,
  280. "timestamp": message_ts,
  281. "state": self.current_state.name
  282. })
  283. return True, message_text
  284. def _determine_state_from_message(self, message_text: Optional[str]) -> DialogueState:
  285. """根据消息内容确定对话状态"""
  286. if not message_text:
  287. logger.warning(f"staff[{self.staff_id}], user[{self.user_id}]: empty message")
  288. return self.current_state
  289. # 简单的规则-关键词匹配
  290. message_lower = message_text.lower()
  291. # 问候检测
  292. greeting_keywords = ["你好", "早上好", "中午好", "晚上好", "嗨", "在吗"]
  293. if any(keyword in message_lower for keyword in greeting_keywords):
  294. return DialogueState.GREETING
  295. # 告别检测
  296. farewell_keywords = ["再见", "拜拜", "晚安", "明天见", "回头见"]
  297. if any(keyword in message_lower for keyword in farewell_keywords):
  298. return DialogueState.FAREWELL
  299. # 澄清请求
  300. # clarification_keywords = ["没明白", "不明白", "没听懂", "不懂", "什么意思", "再说一遍"]
  301. # if any(keyword in message_lower for keyword in clarification_keywords):
  302. # return DialogueState.CLARIFICATION
  303. # 默认为闲聊状态
  304. return DialogueState.CHITCHAT
  305. def _send_alert(self, alert_type: str, reason: Optional[str] = None) -> None:
  306. time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  307. staff_info = f"{self.staff_profile.get('name', '未知')}[{self.staff_id}]"
  308. user_info = f"{self.user_profile.get('nickname', '未知')}[{self.user_id}]"
  309. alert_message = f"""
  310. {alert_type}告警
  311. 员工: {staff_info}
  312. 用户: {user_info}
  313. 时间: {time_str}
  314. 原因:{reason if reason else "未知"}
  315. 最近对话:
  316. """
  317. alert_message = textwrap.dedent(alert_message).strip()
  318. # 添加最近的对话记录
  319. recent_dialogues = self.dialogue_history[-5:]
  320. dialogue_to_send = []
  321. role_map = {'assistant': '客服', 'user': '用户'}
  322. for dialogue in recent_dialogues:
  323. if not dialogue['content']:
  324. continue
  325. role = dialogue['role']
  326. if role not in role_map:
  327. continue
  328. dialogue_to_send.append(f"[{role_map[role]}]{dialogue['content']}")
  329. alert_message += '\n'.join(dialogue_to_send)
  330. if alert_type == '人工介入':
  331. ack_url = "http://ai-wechat-hook.piaoquantv.com/manage/insertEvent?" \
  332. f"sender={self.user_id}&receiver={self.staff_id}&type={MessageType.HUMAN_INTERVENTION_END.value}&content=OPERATION"
  333. else:
  334. ack_url = None
  335. LarkAlertForHumanIntervention().send_lark_alert_for_human_intervention(alert_message, ack_url)
  336. if alert_type == '人工介入':
  337. LarkSheetRecordForHumanIntervention().send_lark_sheet_record_for_human_intervention(
  338. staff_info, user_info, '\n'.join(dialogue_to_send), reason
  339. )
  340. def resume_from_human_intervention(self) -> None:
  341. """从人工介入状态恢复"""
  342. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  343. self.do_state_change(DialogueState.CHITCHAT)
  344. def generate_response(self, llm_response: str) -> Optional[str]:
  345. """
  346. 处理LLM的响应,更新对话状态和对话历史。
  347. 注意:所有的LLM响应都必须经过这个函数来处理!
  348. :param llm_response:
  349. :return:
  350. """
  351. if '<人工介入>' in llm_response:
  352. reason = llm_response.replace('<人工介入>', '')
  353. logger.warning(f'staff[{self.staff_id}], user[{self.user_id}]: human intervention triggered, reason: {reason}')
  354. self.do_state_change(DialogueState.HUMAN_INTERVENTION)
  355. self._send_alert('人工介入', reason)
  356. return None
  357. if '<结束>' in llm_response or '<负向情绪结束>' in llm_response:
  358. logger.warning(f'staff[{self.staff_id}], user[{self.user_id}]: conversation ended')
  359. self.do_state_change(DialogueState.FAREWELL)
  360. if '<负向情绪结束>' in llm_response:
  361. self._send_alert("用户负向情绪")
  362. return None
  363. """根据当前状态处理LLM响应,如果处于人工介入状态则返回None"""
  364. # 如果处于人工介入状态,不生成回复
  365. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  366. return None
  367. # 记录响应到对话历史
  368. message_ts = int(time.time() * 1000)
  369. self.append_dialogue_history({
  370. "role": "assistant",
  371. "type": MessageType.TEXT,
  372. "content": llm_response,
  373. "timestamp": message_ts,
  374. "state": self.current_state.name
  375. })
  376. self.update_interaction_time(message_ts)
  377. return llm_response
  378. def generate_multimodal_response(self, item: Dict) -> Optional[Dict]:
  379. """
  380. 处理LLM的多模态响应,更新对话状态和对话历史。
  381. 注意:所有的LLM多模态响应都必须经过这个函数来处理!
  382. :param item: 包含多模态内容的字典
  383. :return: None
  384. """
  385. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  386. return None
  387. raw_type = item.get("type", "text")
  388. if isinstance(raw_type, str):
  389. item["type"] = MessageType.from_str(raw_type)
  390. if item["type"] == MessageType.TEXT:
  391. if '<人工介入>' in item["content"]:
  392. reason = item["content"].replace('<人工介入>', '')
  393. logger.warning(f'staff[{self.staff_id}], user[{self.user_id}]: human intervention triggered, reason: {reason}')
  394. self.do_state_change(DialogueState.HUMAN_INTERVENTION)
  395. self._send_alert('人工介入', reason)
  396. return None
  397. if '<结束>' in item["content"] or '<负向情绪结束>' in item["content"]:
  398. logger.warning(f'staff[{self.staff_id}], user[{self.user_id}]: conversation ended')
  399. self.do_state_change(DialogueState.FAREWELL)
  400. if '<负向情绪结束>' in item["content"]:
  401. self._send_alert("用户负向情绪")
  402. return None
  403. # 记录响应到对话历史
  404. message_ts = int(time.time() * 1000)
  405. self.append_dialogue_history({
  406. "role": "assistant",
  407. "type": item["type"],
  408. "content": item["content"],
  409. "timestamp": message_ts,
  410. "state": self.current_state.name
  411. })
  412. self.update_interaction_time(message_ts)
  413. return item
  414. def _get_hours_since_last_interaction(self, precision: int = -1):
  415. time_diff = (time.time() * 1000) - self.last_interaction_time_ms
  416. hours_passed = time_diff / 1000 / 3600
  417. if precision >= 0:
  418. return round(hours_passed, precision)
  419. return hours_passed
  420. def update_last_active_interaction_time(self, timestamp_sec: int):
  421. # 只需更新本地时间,重启时可从数据库恢复
  422. self.last_active_interaction_time_sec = timestamp_sec
  423. def should_initiate_conversation(self) -> bool:
  424. """判断是否应该主动发起对话"""
  425. # 如果处于人工介入状态,不应主动发起对话
  426. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  427. return False
  428. hours_passed = self._get_hours_since_last_interaction()
  429. # 获取当前时间上下文
  430. time_context = self.get_time_context()
  431. # 根据用户交互频率偏好设置不同的阈值
  432. interaction_frequency = self.user_profile.get("interaction_frequency", "medium")
  433. if interaction_frequency == 'stopped':
  434. return False
  435. # 设置不同偏好的交互时间阈值(小时)
  436. thresholds = {
  437. "low": 48,
  438. "medium": 24,
  439. "high": 12
  440. }
  441. threshold = thresholds.get(interaction_frequency, 24)
  442. #FIXME 05-21 临时策略,两次主动发起至少48小时
  443. if time.time() - self.last_active_interaction_time_sec < 2 * 24 * 3600:
  444. logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: last active interaction time too short")
  445. return False
  446. if hours_passed < threshold:
  447. return False
  448. # 根据时间上下文决定主动交互的状态
  449. if self.is_time_suitable_for_active_conversation(time_context):
  450. return True
  451. return False
  452. @staticmethod
  453. def is_time_suitable_for_active_conversation(time_context=None) -> bool:
  454. if configs.get_env() == 'dev':
  455. return True
  456. if not time_context:
  457. time_context = DialogueManager.get_time_context()
  458. if time_context in [TimeContext.MORNING,
  459. TimeContext.NOON, TimeContext.AFTERNOON]:
  460. return True
  461. return False
  462. def is_in_human_intervention(self) -> bool:
  463. """检查是否处于人工介入状态"""
  464. return self.current_state == DialogueState.HUMAN_INTERVENTION
  465. def get_prompt_context(self, user_message) -> Dict:
  466. # 获取当前时间上下文
  467. time_context = self.get_time_context()
  468. # 刷新用户画像
  469. self.user_profile = self.user_manager.get_user_profile(self.user_id)
  470. # 刷新员工画像(不一定需要)
  471. self.staff_profile = self.user_manager.get_staff_profile(self.staff_id)
  472. # 员工画像添加前缀,避免冲突,实现Coze Prompt模板的平滑升级
  473. legacy_staff_profile = {}
  474. for key in self.staff_profile:
  475. legacy_staff_profile[f'agent_{key}'] = self.staff_profile[key]
  476. current_datetime = datetime.now()
  477. context = {
  478. "current_state": self.current_state.name,
  479. "previous_state": self.previous_state.name,
  480. "current_time_period": time_context.description,
  481. "current_hour": current_datetime.hour,
  482. "current_time": current_datetime.strftime("%H:%M:%S"),
  483. "current_date": current_datetime.strftime("%Y-%m-%d"),
  484. "current_datetime": current_datetime.strftime("%Y-%m-%d %H:%M:%S"),
  485. "last_interaction_interval": self._get_hours_since_last_interaction(2),
  486. "if_first_interaction": True if self.previous_state == DialogueState.INITIALIZED else False,
  487. "if_active_greeting": False if user_message else True,
  488. "formatted_staff_profile": prompt_utils.format_agent_profile(self.staff_profile),
  489. "formatted_user_profile": prompt_utils.format_user_profile(self.user_profile),
  490. **self.user_profile,
  491. **legacy_staff_profile
  492. }
  493. # 获取长期记忆
  494. relevant_memories = self.vector_memory.retrieve_relevant_memories(user_message)
  495. context["long_term_memory"] = {
  496. "relevant_conversations": relevant_memories
  497. }
  498. return context
  499. @staticmethod
  500. def _select_prompt(state):
  501. state_to_prompt_map = {
  502. DialogueState.GREETING: prompt_templates.GENERAL_GREETING_PROMPT,
  503. DialogueState.CHITCHAT: prompt_templates.CHITCHAT_PROMPT_COZE,
  504. DialogueState.FAREWELL: prompt_templates.GENERAL_GREETING_PROMPT
  505. }
  506. return state_to_prompt_map[state]
  507. @staticmethod
  508. def _select_coze_bot(state, dialogue: List[Dict], multimodal=False):
  509. state_to_bot_map = {
  510. DialogueState.GREETING: '7486112546798780425',
  511. DialogueState.CHITCHAT: '7491300566573301770',
  512. DialogueState.FAREWELL: '7491300566573301770',
  513. }
  514. if multimodal:
  515. state_to_bot_map = {
  516. DialogueState.GREETING: '7496772218198900770',
  517. DialogueState.CHITCHAT: '7495692989504438308',
  518. DialogueState.FAREWELL: '7491300566573301770',
  519. }
  520. return state_to_bot_map[state]
  521. @staticmethod
  522. def need_multimodal_model(dialogue: List[Dict], max_message_to_use: int = 10):
  523. # 当前仅为简单实现
  524. recent_messages = dialogue[-max_message_to_use:]
  525. ret = False
  526. for entry in recent_messages:
  527. if entry.get('type') in (MessageType.IMAGE_GW, MessageType.IMAGE_QW, MessageType.GIF):
  528. ret = True
  529. break
  530. return ret
  531. def _create_system_message(self, prompt_context):
  532. prompt_template = self._select_prompt(self.current_state)
  533. prompt = prompt_template.format(**prompt_context)
  534. return {'role': 'system', 'content': prompt}
  535. @staticmethod
  536. def compose_chat_messages_openai_compatible(dialogue_history, current_time, multimodal=False):
  537. messages = []
  538. for entry in dialogue_history:
  539. role = entry['role']
  540. msg_type = entry.get('type', MessageType.TEXT)
  541. fmt_time = DialogueManager.format_timestamp(entry['timestamp'])
  542. if msg_type in (MessageType.IMAGE_GW, MessageType.IMAGE_QW, MessageType.GIF):
  543. if multimodal:
  544. messages.append({
  545. "role": role,
  546. "content": [
  547. {"type": "image_url", "image_url": {"url": entry["content"]}}
  548. ]
  549. })
  550. else:
  551. logger.warning("Image in non-multimodal mode")
  552. messages.append({
  553. "role": role,
  554. "content": "[{}] {}".format(fmt_time, '[图片]')
  555. })
  556. else:
  557. messages.append({
  558. "role": role,
  559. "content": '[{}] {}'.format(fmt_time, entry["content"])
  560. })
  561. # 添加一条前缀用于 约束时间场景
  562. msg_prefix = '[{}]'.format(current_time)
  563. messages.append({'role': 'assistant', 'content': msg_prefix})
  564. return messages
  565. @staticmethod
  566. def compose_chat_messages_coze(dialogue_history, current_time, staff_id, user_id):
  567. messages = []
  568. # 如果system后的第1条消息不为user,需要在最开始补一条user消息,否则会吞assistant消息
  569. if len(dialogue_history) > 0 and dialogue_history[0]['role'] != 'user':
  570. fmt_time = DialogueManager.format_timestamp(dialogue_history[0]['timestamp'])
  571. messages.append(cozepy.Message.build_user_question_text(f'[{fmt_time}] '))
  572. # coze最后一条消息必须为user,且可能吞掉连续的user消息,故强制增加一条空消息(可参与合并)
  573. dialogue_history.append({
  574. 'role': 'user',
  575. 'content': ' ',
  576. 'timestamp': int(datetime.strptime(current_time, '%Y-%m-%d %H:%M:%S').timestamp() * 1000),
  577. })
  578. # 将连续的同一角色的消息做聚合,避免coze吞消息
  579. messages_to_aggr = []
  580. objects_to_aggr = []
  581. last_message_role = None
  582. for entry in dialogue_history:
  583. if not entry['content']:
  584. logger.warning("staff[{}], user[{}], role[{}]: empty content in dialogue history".format(
  585. staff_id, user_id, entry['role']
  586. ))
  587. continue
  588. role = entry['role']
  589. if role != last_message_role:
  590. if objects_to_aggr:
  591. if last_message_role != 'user':
  592. pass
  593. else:
  594. text_message = '\n'.join(messages_to_aggr)
  595. object_string_list = []
  596. for object_entry in objects_to_aggr:
  597. # FIXME: 其它消息类型的支持
  598. object_string_list.append(cozepy.MessageObjectString.build_image(file_url=object_entry['content']))
  599. object_string_list.append(cozepy.MessageObjectString.build_text(text_message))
  600. messages.append(cozepy.Message.build_user_question_objects(object_string_list))
  601. elif messages_to_aggr:
  602. aggregated_message = '\n'.join(messages_to_aggr)
  603. messages.append(DialogueManager.build_chat_message(
  604. last_message_role, aggregated_message, ChatServiceType.COZE_CHAT))
  605. objects_to_aggr = []
  606. messages_to_aggr = []
  607. last_message_role = role
  608. if entry.get('type', MessageType.TEXT) in (MessageType.IMAGE_GW, MessageType.IMAGE_QW, MessageType.GIF):
  609. # 多模态消息必须用特殊的聚合方式,一个object_string数组中只能有一个文字消息,但可以有多个图片
  610. if role == 'user':
  611. objects_to_aggr.append(entry)
  612. else:
  613. logger.warning("staff[{}], user[{}]: unsupported message type [{}] in assistant role".format(
  614. staff_id, user_id, entry['type']
  615. ))
  616. else:
  617. messages_to_aggr.append(DialogueManager.format_dialogue_content(entry))
  618. # 如果有未聚合的object消息,需要特殊处理
  619. if objects_to_aggr:
  620. if last_message_role != 'user':
  621. pass
  622. else:
  623. text_message = '\n'.join(messages_to_aggr)
  624. object_string_list = []
  625. for object_entry in objects_to_aggr:
  626. # FIXME: 其它消息类型的支持
  627. object_string_list.append(cozepy.MessageObjectString.build_image(file_url=object_entry['content']))
  628. object_string_list.append(cozepy.MessageObjectString.build_text(text_message))
  629. messages.append(cozepy.Message.build_user_question_objects(object_string_list))
  630. elif messages_to_aggr:
  631. aggregated_message = '\n'.join(messages_to_aggr)
  632. messages.append(DialogueManager.build_chat_message(
  633. last_message_role, aggregated_message, ChatServiceType.COZE_CHAT))
  634. # 从末尾开始往前遍历,如果assistant曾经回复“无法回答”,则清除当前消息和前一条用户消息
  635. idx = len(messages) - 1
  636. while idx >= 0:
  637. if messages[idx].role == 'assistant' and '无法回答' in messages[idx].content:
  638. messages.pop(idx)
  639. idx -= 1
  640. if idx >= 0:
  641. messages.pop(idx)
  642. idx -= 1
  643. else:
  644. idx -= 1
  645. return messages
  646. def build_active_greeting_config(self, user_tags: List[str]):
  647. # FIXME: 这里的抽象不好,短期支持人为配置实验
  648. # 由于产运要求,指定使用GPT-4o模型
  649. chat_config = {'user_id': self.user_id, 'model_name': chat_service.OPENAI_MODEL_GPT_4o}
  650. prompt_context = self.get_prompt_context(None)
  651. current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  652. system_message = {'role': 'system', 'content': 'You are a helpful AI assistant.'}
  653. # TODO: 随机选择一个prompt 或 带策略选择 或根据用户标签选择
  654. # TODO:需要区分用户是否有历史交互、是否发送过相似内容
  655. greeting_prompts = [
  656. prompt_templates.GREETING_WITH_IMAGE_GAME,
  657. prompt_templates.GREETING_WITH_NAME_POETRY,
  658. prompt_templates.GREETING_WITH_AVATAR_STORY
  659. ]
  660. # 默认随机选择
  661. selected_prompt = greeting_prompts[random.randint(0, len(greeting_prompts) - 1)]
  662. # 实验配置
  663. tag_to_greeting_map = {
  664. '04W4-AA-1': prompt_templates.GREETING_WITH_NAME_POETRY,
  665. '04W4-AA-2': prompt_templates.GREETING_WITH_AVATAR_STORY,
  666. '04W4-AA-3': prompt_templates.GREETING_WITH_INTEREST_QUERY,
  667. '04W4-AA-4': prompt_templates.GREETING_WITH_CALENDAR,
  668. }
  669. for tag in user_tags:
  670. if tag in tag_to_greeting_map:
  671. selected_prompt = tag_to_greeting_map[tag]
  672. prompt = selected_prompt.format(**prompt_context)
  673. user_message = {'role': 'user', 'content': prompt}
  674. messages = [system_message, user_message]
  675. if selected_prompt in (
  676. prompt_templates.GREETING_WITH_AVATAR_STORY,
  677. prompt_templates.GREETING_WITH_INTEREST_QUERY,
  678. ):
  679. messages.append({
  680. "role": 'user',
  681. "content": [
  682. {"type": "image_url", "image_url": {"url": self.user_profile['avatar']}}
  683. ]
  684. })
  685. chat_config['use_multimodal_model'] = True
  686. chat_config['messages'] = messages
  687. return chat_config
  688. def build_chat_configuration(
  689. self,
  690. user_message: Optional[str] = None,
  691. chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE,
  692. overwrite_context: Optional[Dict] = None
  693. ) -> Dict:
  694. """
  695. 参数:
  696. user_message: 当前用户消息,如果是主动交互则为None
  697. 返回:
  698. 消息列表
  699. """
  700. dialogue_history = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id)
  701. logger.debug("staff[{}], user[{}], recent dialogue_history: {}".format(
  702. self.staff_id, self.user_id, dialogue_history[-20:]
  703. ))
  704. messages = []
  705. config = {
  706. 'user_id': self.user_id
  707. }
  708. prompt_context = self.get_prompt_context(user_message)
  709. if overwrite_context:
  710. prompt_context.update(overwrite_context)
  711. # FIXME(zhoutian): time in string type
  712. current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  713. if overwrite_context and 'current_time' in overwrite_context:
  714. current_time = overwrite_context.get('current_time')
  715. need_multimodal = self.need_multimodal_model(dialogue_history)
  716. config['use_multimodal_model'] = need_multimodal
  717. if chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
  718. system_message = self._create_system_message(prompt_context)
  719. messages.append(system_message)
  720. messages.extend(self.compose_chat_messages_openai_compatible(dialogue_history, current_time, need_multimodal))
  721. elif chat_service_type == ChatServiceType.COZE_CHAT:
  722. dialogue_history = dialogue_history[-95:] # Coze最多支持100条,还需要附加系统消息
  723. messages = self.compose_chat_messages_coze(dialogue_history, current_time, self.staff_id, self.user_id)
  724. custom_variables = {}
  725. for k, v in prompt_context.items():
  726. custom_variables[k] = str(v)
  727. custom_variables.pop('user_profile', None)
  728. config['custom_variables'] = custom_variables
  729. config['bot_id'] = self._select_coze_bot(self.current_state, dialogue_history, need_multimodal)
  730. #FIXME(zhoutian): 临时报警
  731. if user_message and not messages:
  732. logger.error(f"staff[{self.staff_id}], user[{self.user_id}]: inconsistency in messages")
  733. config['messages'] = messages
  734. return config
  735. @staticmethod
  736. def format_timestamp(timestamp_ms):
  737. return datetime.fromtimestamp(timestamp_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  738. @staticmethod
  739. def format_dialogue_content(dialogue_entry):
  740. fmt_time = DialogueManager.format_timestamp(dialogue_entry['timestamp'])
  741. content = '[{}] {}'.format(fmt_time, dialogue_entry['content'])
  742. return content
  743. @staticmethod
  744. def build_chat_message(role, content, chat_service_type: ChatServiceType):
  745. if chat_service_type == ChatServiceType.COZE_CHAT:
  746. if role == 'user':
  747. return cozepy.Message.build_user_question_text(content)
  748. elif role == 'assistant':
  749. return cozepy.Message.build_assistant_answer(content)
  750. else:
  751. return {'role': role, 'content': content}
  752. if __name__ == '__main__':
  753. state_cache = DialogueStateCache()
  754. state_cache.set_state('1688854492669990', '7881302581935903', DialogueState.CHITCHAT, DialogueState.GREETING)