dialogue_manager.py 32 KB


  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import random
  5. from enum import Enum, auto
  6. from typing import Dict, List, Optional, Tuple, Any
  7. from datetime import datetime
  8. import time
  9. import prompt_templates
  10. from logging_service import logger
  11. import pymysql.cursors
  12. import configs
  13. import cozepy
  14. from database import MySQLManager
  15. from history_dialogue_service import HistoryDialogueService
  16. from chat_service import ChatServiceType
  17. from message import MessageType, Message
  18. from user_manager import UserManager
  19. from prompt_templates import *
  20. class DummyVectorMemoryManager:
  21. def __init__(self, user_id):
  22. pass
  23. def add_to_memory(self, conversation):
  24. pass
  25. def retrieve_relevant_memories(self, query, k=3):
  26. return []
class DialogueState(int, Enum):
    """Integer-valued states a staff/user dialogue can be in."""

    INITIALIZED = 0
    GREETING = 1  # greeting state
    CHITCHAT = 2  # casual-chat state
    CLARIFICATION = 3  # clarification state
    FAREWELL = 4  # farewell state
    HUMAN_INTERVENTION = 5  # human-intervention state
    MESSAGE_AGGREGATING = 6  # waiting-for-messages state
class TimeContext(Enum):
    """Named periods of the day; each member's value is its Chinese label."""

    EARLY_MORNING = "清晨"  # early morning (5:00-7:59)
    MORNING = "上午"  # morning (8:00-11:59)
    NOON = "中午"  # noon (12:00-13:59)
    AFTERNOON = "下午"  # afternoon (14:00-17:59)
    EVENING = "晚上"  # evening (18:00-21:59)
    NIGHT = "深夜"  # night (22:00-4:59)

    def __init__(self, description):
        # Enum passes the member's value to __init__, so `description`
        # holds the same label string as `.value`.
        self.description = description
class DialogueStateChangeType(int, Enum):
    """Kinds of change that can be recorded in a DialogueStateChange."""

    STATE = 0  # the (current, previous) dialogue state pair changed
    INTERACTION_TIME = 1  # the last-interaction timestamp changed
    DIALOGUE_HISTORY = 2  # an entry was appended to the dialogue history
  48. class DialogueStateChange:
  49. def __init__(self, event_type: DialogueStateChangeType,old: Any, new: Any):
  50. self.event_type = event_type
  51. self.old = old
  52. self.new = new
  53. class DialogueStateCache:
  54. def __init__(self):
  55. self.config = configs.get()
  56. self.db = MySQLManager(self.config['storage']['agent_state']['mysql'])
  57. self.table = self.config['storage']['agent_state']['table']
  58. def get_state(self, staff_id: str, user_id: str) -> Tuple[DialogueState, DialogueState]:
  59. query = f"SELECT current_state, previous_state FROM {self.table} WHERE staff_id=%s AND user_id=%s"
  60. data = self.db.select(query, pymysql.cursors.DictCursor, (staff_id, user_id))
  61. if not data:
  62. logger.warning(f"staff[{staff_id}], user[{user_id}]: agent state not found")
  63. state = DialogueState.INITIALIZED
  64. previous_state = DialogueState.INITIALIZED
  65. self.set_state(staff_id, user_id, state, previous_state)
  66. else:
  67. state = DialogueState(data[0]['current_state'])
  68. previous_state = DialogueState(data[0]['previous_state'])
  69. return state, previous_state
  70. def set_state(self, staff_id: str, user_id: str, state: DialogueState, previous_state: DialogueState):
  71. if self.config.get('debug_flags', {}).get('disable_database_write', False):
  72. return
  73. query = f"INSERT INTO {self.table} (staff_id, user_id, current_state, previous_state)" \
  74. f" VALUES (%s, %s, %s, %s) " \
  75. f"ON DUPLICATE KEY UPDATE current_state=%s, previous_state=%s"
  76. rows = self.db.execute(query, (staff_id, user_id, state.value, previous_state.value, state.value, previous_state.value))
  77. logger.debug("staff[{}], user[{}]: set state: {}, previous state: {}, rows affected: {}"
  78. .format(staff_id, user_id, state, previous_state, rows))
  79. class DialogueManager:
  80. def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache):
  81. config = configs.get()
  82. self.staff_id = staff_id
  83. self.user_id = user_id
  84. self.user_manager = user_manager
  85. self.state_cache = state_cache
  86. self.current_state = DialogueState.GREETING
  87. self.previous_state = DialogueState.INITIALIZED
  88. # 目前实际仅用作调试,拼装prompt时使用history_dialogue_service获取
  89. self.dialogue_history = []
  90. self.user_profile = self.user_manager.get_user_profile(user_id)
  91. self.staff_profile = self.user_manager.get_staff_profile(staff_id)
  92. # FIXME: 交互时间和对话记录都涉及到回滚
  93. self.last_interaction_time = 0
  94. self.consecutive_clarifications = 0
  95. self.complex_request_counter = 0
  96. self.human_intervention_triggered = False
  97. self.vector_memory = DummyVectorMemoryManager(user_id)
  98. self.message_aggregation_sec = config.get('agent_behavior', {}).get('message_aggregation_sec', 5)
  99. self.unprocessed_messages = []
  100. self.history_dialogue_service = HistoryDialogueService(
  101. config['storage']['history_dialogue']['api_base_url']
  102. )
  103. self._recover_state()
  104. # 由于本地状态管理过于复杂,引入事务机制做状态回滚
  105. self._uncommited_state_change = []
  106. @staticmethod
  107. def get_time_context(current_hour=None) -> TimeContext:
  108. """获取当前时间上下文"""
  109. if not current_hour:
  110. current_hour = datetime.now().hour
  111. if 5 <= current_hour < 8:
  112. return TimeContext.EARLY_MORNING
  113. elif 8 <= current_hour < 12:
  114. return TimeContext.MORNING
  115. elif 12 <= current_hour < 14:
  116. return TimeContext.NOON
  117. elif 14 <= current_hour < 18:
  118. return TimeContext.AFTERNOON
  119. elif 18 <= current_hour < 22:
  120. return TimeContext.EVENING
  121. else:
  122. return TimeContext.NIGHT
  123. def _recover_state(self):
  124. self.current_state, self.previous_state = self.state_cache.get_state(self.staff_id, self.user_id)
  125. # 从数据库恢复对话状态
  126. self.dialogue_history = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id)
  127. if self.dialogue_history:
  128. self.last_interaction_time = self.dialogue_history[-1]['timestamp']
  129. else:
  130. # 默认设置为24小时前
  131. self.last_interaction_time = int(time.time() * 1000) - 24 * 3600 * 1000
  132. time_for_read = datetime.fromtimestamp(self.last_interaction_time / 1000).strftime("%Y-%m-%d %H:%M:%S")
  133. logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")
  134. def update_interaction_time(self, timestamp_ms: int):
  135. self._uncommited_state_change.append(DialogueStateChange(
  136. DialogueStateChangeType.INTERACTION_TIME,
  137. self.last_interaction_time,
  138. timestamp_ms
  139. ))
  140. self.last_interaction_time = timestamp_ms
  141. def append_dialogue_history(self, message: Dict):
  142. self._uncommited_state_change.append(DialogueStateChange(
  143. DialogueStateChangeType.DIALOGUE_HISTORY,
  144. None,
  145. 1
  146. ))
  147. self.dialogue_history.append(message)
  148. def persist_state(self):
  149. """持久化对话状态,只有当前状态处理成功后才应该做持久化"""
  150. self.commit()
  151. config = configs.get()
  152. if config.get('debug_flags', {}).get('disable_database_write', False):
  153. return
  154. self.state_cache.set_state(self.staff_id, self.user_id, self.current_state, self.previous_state)
  155. def rollback_state(self):
  156. logger.info(f"staff[{self.staff_id}], user[{self.user_id}]: reverse state")
  157. for entry in reversed(self._uncommited_state_change):
  158. if entry.event_type == DialogueStateChangeType.STATE:
  159. self.current_state, self.previous_state = entry.old
  160. elif entry.event_type == DialogueStateChangeType.INTERACTION_TIME:
  161. self.last_interaction_time = entry.old
  162. elif entry.event_type == DialogueStateChangeType.DIALOGUE_HISTORY:
  163. self.dialogue_history.pop()
  164. else:
  165. logger.error(f"unimplemented type: [{entry.event_type}]")
  166. self._uncommited_state_change.clear()
    def commit(self):
        """Discard the journal of uncommitted changes, making them permanent."""
        self._uncommited_state_change.clear()
  169. def do_state_change(self, state: DialogueState):
  170. state_backup = (self.current_state, self.previous_state)
  171. if self.current_state == DialogueState.MESSAGE_AGGREGATING:
  172. # MESSAGE_AGGREGATING不能成为previous_state,仅使用state_backup做回退
  173. self.current_state = state
  174. else:
  175. self.previous_state = self.current_state
  176. self.current_state = state
  177. self._uncommited_state_change.append(DialogueStateChange(
  178. DialogueStateChangeType.STATE,
  179. state_backup,
  180. (self.current_state, self.previous_state)
  181. ))
  182. def update_state(self, message: Message) -> Tuple[bool, Optional[str]]:
  183. """根据用户消息更新对话状态,并返回是否需要发起回复 及下一条需处理的用户消息"""
  184. message_text = message.content
  185. message_ts = message.sendTime
  186. # 如果当前已经是人工介入状态,保持该状态
  187. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  188. # 记录对话历史,但不改变状态
  189. self.append_dialogue_history({
  190. "role": "user",
  191. "content": message_text,
  192. "timestamp": int(time.time() * 1000),
  193. "state": self.current_state.name
  194. })
  195. return False, message_text
  196. # 检查是否处于消息聚合状态
  197. if self.current_state == DialogueState.MESSAGE_AGGREGATING:
  198. # 收到的是特殊定时触发的空消息,且在聚合中,且已经超时,继续处理
  199. if message.type == MessageType.AGGREGATION_TRIGGER:
  200. if message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
  201. logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: exit aggregation waiting")
  202. else:
  203. logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: continue aggregation waiting")
  204. return False, message_text
  205. else:
  206. # 非空消息,更新最后交互时间,保持消息聚合状态
  207. if message_text:
  208. self.unprocessed_messages.append(message_text)
  209. self.update_interaction_time(message_ts)
  210. return False, message_text
  211. else:
  212. if message.type == MessageType.AGGREGATION_TRIGGER:
  213. # 未在聚合状态中,收到的聚合触发消息为过时消息,不应当处理
  214. logger.warning(f"staff[{self.staff_id}], user[{self.user_id}]: received {message.type} in state {self.current_state}")
  215. return False, None
  216. if message.type != MessageType.AGGREGATION_TRIGGER and self.message_aggregation_sec > 0:
  217. # 收到有内容的用户消息,切换到消息聚合状态
  218. self.do_state_change(DialogueState.MESSAGE_AGGREGATING)
  219. self.unprocessed_messages.append(message_text)
  220. # 更新最后交互时间
  221. if message_text:
  222. self.update_interaction_time(message_ts)
  223. return False, message_text
  224. # 获得未处理的聚合消息,并清空未处理队列
  225. if message_text:
  226. self.unprocessed_messages.append(message_text)
  227. if self.unprocessed_messages:
  228. message_text = '\n'.join(self.unprocessed_messages)
  229. self.unprocessed_messages.clear()
  230. # 根据消息内容和当前状态确定新状态
  231. new_state = self._determine_state_from_message(message_text)
  232. # 处理连续澄清的情况
  233. if new_state == DialogueState.CLARIFICATION:
  234. self.consecutive_clarifications += 1
  235. # FIXME(zhoutian): 规则过于简单
  236. if self.consecutive_clarifications >= 10000:
  237. new_state = DialogueState.HUMAN_INTERVENTION
  238. # self._trigger_human_intervention("连续多次澄清请求")
  239. else:
  240. self.consecutive_clarifications = 0
  241. # 更新状态
  242. self.do_state_change(new_state)
  243. if message_text:
  244. self.update_interaction_time(message_ts)
  245. self.append_dialogue_history({
  246. "role": "user",
  247. "content": message_text,
  248. "timestamp": message_ts,
  249. "state": self.current_state.name
  250. })
  251. return True, message_text
  252. def _determine_state_from_message(self, message_text: Optional[str]) -> DialogueState:
  253. """根据消息内容确定对话状态"""
  254. if not message_text:
  255. logger.warning(f"staff[{self.staff_id}], user[{self.user_id}]: empty message")
  256. return self.current_state
  257. # 简单的规则-关键词匹配
  258. message_lower = message_text.lower()
  259. # 判断是否是复杂请求
  260. # FIXME(zhoutian): 规则过于简单
  261. # complex_request_keywords = ["帮我", "怎么办", "我需要", "麻烦你", "请帮助", "急", "紧急"]
  262. # if any(keyword in message_lower for keyword in complex_request_keywords):
  263. # self.complex_request_counter += 1
  264. #
  265. # # 如果检测到困难请求且计数达到阈值,触发人工介入
  266. # if self.complex_request_counter >= 1:
  267. # # self._trigger_human_intervention("检测到复杂请求")
  268. # return DialogueState.HUMAN_INTERVENTION
  269. # else:
  270. # # 如果不是复杂请求,重置计数器
  271. # self.complex_request_counter = 0
  272. # 问候检测
  273. greeting_keywords = ["你好", "早上好", "中午好", "晚上好", "嗨", "在吗"]
  274. if any(keyword in message_lower for keyword in greeting_keywords):
  275. return DialogueState.GREETING
  276. # 告别检测
  277. farewell_keywords = ["再见", "拜拜", "晚安", "明天见", "回头见"]
  278. if any(keyword in message_lower for keyword in farewell_keywords):
  279. return DialogueState.FAREWELL
  280. # 澄清请求
  281. # clarification_keywords = ["没明白", "不明白", "没听懂", "不懂", "什么意思", "再说一遍"]
  282. # if any(keyword in message_lower for keyword in clarification_keywords):
  283. # return DialogueState.CLARIFICATION
  284. # 默认为闲聊状态
  285. return DialogueState.CHITCHAT
  286. def _trigger_human_intervention(self, reason: str) -> None:
  287. """触发人工介入"""
  288. if not self.human_intervention_triggered:
  289. self.human_intervention_triggered = True
  290. # 记录人工介入事件
  291. event = {
  292. "timestamp": int(time.time() * 1000),
  293. "reason": reason,
  294. "dialogue_context": self.dialogue_history[-10:]
  295. }
  296. # 更新用户资料中的人工介入历史
  297. if "human_intervention_history" not in self.user_profile:
  298. self.user_profile["human_intervention_history"] = []
  299. self.user_profile["human_intervention_history"].append(event)
  300. self.user_manager.save_user_profile(self.user_id, self.user_profile)
  301. # 发送告警
  302. self._send_human_intervention_alert(reason)
  303. def _send_human_intervention_alert(self, reason: str) -> None:
  304. alert_message = f"""
  305. 人工介入告警
  306. 用户ID: {self.user_id}
  307. 用户昵称: {self.user_profile.get("nickname", "未知")}
  308. 时间: {int(time.time() * 1000)}
  309. 原因: {reason}
  310. 最近对话:
  311. """
  312. # 添加最近的对话记录
  313. recent_dialogues = self.dialogue_history[-10:]
  314. for dialogue in recent_dialogues:
  315. alert_message += f"\n{dialogue['role']}: {dialogue['content']}"
  316. # TODO(zhoutian): 实现发送告警的具体逻辑
  317. logger.warning(alert_message)
  318. def resume_from_human_intervention(self) -> None:
  319. """从人工介入状态恢复"""
  320. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  321. self.do_state_change(DialogueState.CHITCHAT)
  322. self.human_intervention_triggered = False
  323. self.consecutive_clarifications = 0
  324. self.complex_request_counter = 0
  325. # 记录恢复事件
  326. self.append_dialogue_history({
  327. "role": "system",
  328. "content": "已从人工介入状态恢复到自动对话",
  329. "timestamp": int(time.time() * 1000),
  330. "state": self.current_state.name
  331. })
  332. def generate_response(self, llm_response: str) -> Optional[str]:
  333. """根据当前状态处理LLM响应,如果处于人工介入状态则返回None"""
  334. # 如果处于人工介入状态,不生成回复
  335. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  336. return None
  337. # 记录响应到对话历史
  338. message_ts = int(time.time() * 1000)
  339. self.append_dialogue_history({
  340. "role": "assistant",
  341. "content": llm_response,
  342. "timestamp": message_ts,
  343. "state": self.current_state.name
  344. })
  345. self.update_interaction_time(message_ts)
  346. return llm_response
  347. def _get_hours_since_last_interaction(self, precision: int = -1):
  348. time_diff = (time.time() * 1000) - self.last_interaction_time
  349. hours_passed = time_diff / 1000 / 3600
  350. if precision >= 0:
  351. return round(hours_passed, precision)
  352. return hours_passed
  353. def should_initiate_conversation(self) -> bool:
  354. """判断是否应该主动发起对话"""
  355. # 如果处于人工介入状态,不应主动发起对话
  356. if self.current_state == DialogueState.HUMAN_INTERVENTION:
  357. return False
  358. hours_passed = self._get_hours_since_last_interaction()
  359. # 获取当前时间上下文
  360. time_context = self.get_time_context()
  361. # 根据用户交互频率偏好设置不同的阈值
  362. interaction_frequency = self.user_profile.get("interaction_frequency", "medium")
  363. if interaction_frequency == 'stopped':
  364. return False
  365. # 设置不同偏好的交互时间阈值(小时)
  366. thresholds = {
  367. "low": 24, # 低频率:一天一次
  368. "medium": 12, # 中频率:半天一次
  369. "high": 6 # 高频率:大约6小时一次
  370. }
  371. threshold = thresholds.get(interaction_frequency, 12)
  372. if hours_passed < threshold:
  373. return False
  374. # 根据时间上下文决定主动交互的状态
  375. if time_context in [TimeContext.MORNING,
  376. TimeContext.NOON, TimeContext.AFTERNOON]:
  377. return True
  378. return False
    def is_in_human_intervention(self) -> bool:
        """Return True when the current state is HUMAN_INTERVENTION."""
        return self.current_state == DialogueState.HUMAN_INTERVENTION
  382. def get_prompt_context(self, user_message) -> Dict:
  383. # 获取当前时间上下文
  384. time_context = self.get_time_context()
  385. # 刷新用户画像
  386. self.user_profile = self.user_manager.get_user_profile(self.user_id)
  387. # 刷新员工画像(不一定需要)
  388. self.staff_profile = self.user_manager.get_staff_profile(self.staff_id)
  389. context = {
  390. "user_profile": self.user_profile,
  391. "current_state": self.current_state.name,
  392. "previous_state": self.previous_state.name,
  393. "current_time_period": time_context.description,
  394. "current_hour": datetime.now().hour,
  395. "last_interaction_interval": self._get_hours_since_last_interaction(2),
  396. "if_first_interaction": True if self.previous_state == DialogueState.INITIALIZED else False,
  397. "if_active_greeting": False if user_message else True,
  398. **self.user_profile,
  399. **self.staff_profile
  400. }
  401. # 获取长期记忆
  402. relevant_memories = self.vector_memory.retrieve_relevant_memories(user_message)
  403. context["long_term_memory"] = {
  404. "relevant_conversations": relevant_memories
  405. }
  406. return context
  407. @staticmethod
  408. def _select_prompt(state):
  409. state_to_prompt_map = {
  410. DialogueState.GREETING: GENERAL_GREETING_PROMPT,
  411. DialogueState.CHITCHAT: CHITCHAT_PROMPT_COZE,
  412. DialogueState.FAREWELL: GENERAL_GREETING_PROMPT
  413. }
  414. return state_to_prompt_map[state]
  415. @staticmethod
  416. def _select_coze_bot(state, dialogue: List[Dict], multimodal=False):
  417. state_to_bot_map = {
  418. DialogueState.GREETING: '7486112546798780425',
  419. DialogueState.CHITCHAT: '7491300566573301770',
  420. DialogueState.FAREWELL: '7491300566573301770',
  421. }
  422. if multimodal:
  423. state_to_bot_map = {
  424. DialogueState.GREETING: '7496772218198900770',
  425. DialogueState.CHITCHAT: '7495692989504438308',
  426. DialogueState.FAREWELL: '7491300566573301770',
  427. }
  428. return state_to_bot_map[state]
  429. @staticmethod
  430. def need_multimodal_model(dialogue: List[Dict], max_message_to_use: int = 10):
  431. # 当前仅为简单实现
  432. recent_messages = dialogue[-max_message_to_use:]
  433. ret = False
  434. for entry in recent_messages:
  435. if entry.get('type') in (MessageType.IMAGE_GW, MessageType.IMAGE_QW, MessageType.GIF):
  436. ret = True
  437. break
  438. return ret
  439. def _create_system_message(self, prompt_context):
  440. prompt_template = self._select_prompt(self.current_state)
  441. prompt = prompt_template.format(**prompt_context)
  442. return {'role': 'system', 'content': prompt}
  443. @staticmethod
  444. def compose_chat_messages_openai_compatible(dialogue_history, current_time, multimodal=False):
  445. messages = []
  446. for entry in dialogue_history:
  447. role = entry['role']
  448. msg_type = entry.get('type', MessageType.TEXT)
  449. fmt_time = DialogueManager.format_timestamp(entry['timestamp'])
  450. if msg_type in (MessageType.IMAGE_GW, MessageType.IMAGE_QW, MessageType.GIF):
  451. if multimodal:
  452. messages.append({
  453. "role": role,
  454. "content": [
  455. {"type": "image_url", "image_url": {"url": entry["content"]}}
  456. ]
  457. })
  458. else:
  459. logger.warning("Image in non-multimodal mode")
  460. messages.append({
  461. "role": role,
  462. "content": "[{}] {}".format(fmt_time, '[图片]')
  463. })
  464. else:
  465. messages.append({
  466. "role": role,
  467. "content": '[{}] {}'.format(fmt_time, entry["content"])
  468. })
  469. # 添加一条前缀用于 约束时间场景
  470. msg_prefix = '[{}]'.format(current_time)
  471. messages.append({'role': 'assistant', 'content': msg_prefix})
  472. return messages
  473. @staticmethod
  474. def compose_chat_messages_coze(dialogue_history, current_time, staff_id, user_id):
  475. messages = []
  476. # 如果system后的第1条消息不为user,需要在最开始补一条user消息,否则会吞assistant消息
  477. if len(dialogue_history) > 0 and dialogue_history[0]['role'] != 'user':
  478. fmt_time = DialogueManager.format_timestamp(dialogue_history[0]['timestamp'])
  479. messages.append(cozepy.Message.build_user_question_text(f'[{fmt_time}] '))
  480. # coze最后一条消息必须为user,且可能吞掉连续的user消息,故强制增加一条空消息(可参与合并)
  481. dialogue_history.append({
  482. 'role': 'user',
  483. 'content': ' ',
  484. 'timestamp': int(datetime.strptime(current_time, '%Y-%m-%d %H:%M:%S').timestamp() * 1000),
  485. })
  486. # 将连续的同一角色的消息做聚合,避免coze吞消息
  487. messages_to_aggr = []
  488. objects_to_aggr = []
  489. last_message_role = None
  490. for entry in dialogue_history:
  491. if not entry['content']:
  492. logger.warning("staff[{}], user[{}], role[{}]: empty content in dialogue history".format(
  493. staff_id, user_id, entry['role']
  494. ))
  495. continue
  496. role = entry['role']
  497. if role != last_message_role:
  498. if objects_to_aggr:
  499. if last_message_role != 'user':
  500. pass
  501. else:
  502. text_message = '\n'.join(messages_to_aggr)
  503. object_string_list = []
  504. for object_entry in objects_to_aggr:
  505. # FIXME: 其它消息类型的支持
  506. object_string_list.append(cozepy.MessageObjectString.build_image(file_url=object_entry['content']))
  507. object_string_list.append(cozepy.MessageObjectString.build_text(text_message))
  508. messages.append(cozepy.Message.build_user_question_objects(object_string_list))
  509. elif messages_to_aggr:
  510. aggregated_message = '\n'.join(messages_to_aggr)
  511. messages.append(DialogueManager.build_chat_message(
  512. last_message_role, aggregated_message, ChatServiceType.COZE_CHAT))
  513. objects_to_aggr = []
  514. messages_to_aggr = []
  515. last_message_role = role
  516. if entry.get('type', MessageType.TEXT) in (MessageType.IMAGE_GW, MessageType.IMAGE_QW, MessageType.GIF):
  517. # 多模态消息必须用特殊的聚合方式,一个object_string数组中只能有一个文字消息,但可以有多个图片
  518. if role == 'user':
  519. objects_to_aggr.append(entry)
  520. else:
  521. logger.warning("staff[{}], user[{}]: unsupported message type [{}] in assistant role".format(
  522. staff_id, user_id, entry['type']
  523. ))
  524. else:
  525. messages_to_aggr.append(DialogueManager.format_dialogue_content(entry))
  526. # 如果有未聚合的object消息,需要特殊处理
  527. if objects_to_aggr:
  528. if last_message_role != 'user':
  529. pass
  530. else:
  531. text_message = '\n'.join(messages_to_aggr)
  532. object_string_list = []
  533. for object_entry in objects_to_aggr:
  534. # FIXME: 其它消息类型的支持
  535. object_string_list.append(cozepy.MessageObjectString.build_image(file_url=object_entry['content']))
  536. object_string_list.append(cozepy.MessageObjectString.build_text(text_message))
  537. messages.append(cozepy.Message.build_user_question_objects(object_string_list))
  538. elif messages_to_aggr:
  539. aggregated_message = '\n'.join(messages_to_aggr)
  540. messages.append(DialogueManager.build_chat_message(
  541. last_message_role, aggregated_message, ChatServiceType.COZE_CHAT))
  542. return messages
  543. def build_active_greeting_config(self, user_tags: List[str]):
  544. # FIXME: 这里的抽象不好,短期支持人为配置实验
  545. chat_config = {'user_id': self.user_id}
  546. prompt_context = self.get_prompt_context(None)
  547. current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  548. system_message = {'role': 'system', 'content': 'You are a helpful AI assistant.'}
  549. # TODO: 随机选择一个prompt 或 带策略选择 或根据用户标签选择
  550. # TODO:需要区分用户是否有历史交互、是否发送过相似内容
  551. greeting_prompts = [
  552. prompt_templates.GREETING_WITH_IMAGE_GAME,
  553. prompt_templates.GREETING_WITH_NAME_POETRY,
  554. prompt_templates.GREETING_WITH_AVATAR_STORY
  555. ]
  556. # 默认随机选择
  557. selected_prompt = greeting_prompts[random.randint(0, len(greeting_prompts) - 1)]
  558. # 实验配置
  559. tag_to_greeting_map = {
  560. '04W4-AA-1': prompt_templates.GREETING_WITH_NAME_POETRY,
  561. '04W4-AA-2': prompt_templates.GREETING_WITH_AVATAR_STORY,
  562. '04W4-AA-3': prompt_templates.GREETING_WITH_AVATAR_STORY,
  563. '04W4-AA-4': prompt_templates.GREETING_WITH_IMAGE_GAME,
  564. }
  565. for tag in user_tags:
  566. if tag in tag_to_greeting_map:
  567. selected_prompt = tag_to_greeting_map[tag]
  568. prompt = selected_prompt.format(**prompt_context)
  569. user_message = {'role': 'user', 'content': prompt}
  570. messages = [system_message, user_message]
  571. if selected_prompt == prompt_templates.GREETING_WITH_AVATAR_STORY:
  572. messages.append({
  573. "role": 'user',
  574. "content": [
  575. {"type": "image_url", "image_url": {"url": self.user_profile['avatar']}}
  576. ]
  577. })
  578. chat_config['use_multimodal_model'] = True
  579. chat_config['messages'] = messages
  580. return chat_config
  581. def build_chat_configuration(
  582. self,
  583. user_message: Optional[str] = None,
  584. chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE,
  585. overwrite_context: Optional[Dict] = None
  586. ) -> Dict:
  587. """
  588. 参数:
  589. user_message: 当前用户消息,如果是主动交互则为None
  590. 返回:
  591. 消息列表
  592. """
  593. dialogue_history = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id)
  594. logger.debug("staff[{}], user[{}], dialogue_history: {}".format(
  595. self.staff_id, self.user_id, dialogue_history
  596. ))
  597. messages = []
  598. config = {
  599. 'user_id': self.user_id
  600. }
  601. prompt_context = self.get_prompt_context(user_message)
  602. if overwrite_context:
  603. prompt_context.update(overwrite_context)
  604. # FIXME(zhoutian): time in string type
  605. current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
  606. if overwrite_context and 'current_time' in overwrite_context:
  607. current_time = overwrite_context.get('current_time')
  608. need_multimodal = self.need_multimodal_model(dialogue_history)
  609. config['use_multimodal_model'] = need_multimodal
  610. if chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
  611. system_message = self._create_system_message(prompt_context)
  612. messages.append(system_message)
  613. messages.extend(self.compose_chat_messages_openai_compatible(dialogue_history, current_time, need_multimodal))
  614. elif chat_service_type == ChatServiceType.COZE_CHAT:
  615. dialogue_history = dialogue_history[-95:] # Coze最多支持100条,还需要附加系统消息
  616. messages = self.compose_chat_messages_coze(dialogue_history, current_time, self.staff_id, self.user_id)
  617. custom_variables = {}
  618. for k, v in prompt_context.items():
  619. custom_variables[k] = str(v)
  620. custom_variables.pop('user_profile', None)
  621. config['custom_variables'] = custom_variables
  622. config['bot_id'] = self._select_coze_bot(self.current_state, dialogue_history, need_multimodal)
  623. #FIXME(zhoutian): 临时报警
  624. if user_message and not messages:
  625. logger.error(f"staff[{self.staff_id}], user[{self.user_id}]: inconsistency in messages")
  626. config['messages'] = messages
  627. return config
  628. @staticmethod
  629. def format_timestamp(timestamp_ms):
  630. return datetime.fromtimestamp(timestamp_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  631. @staticmethod
  632. def format_dialogue_content(dialogue_entry):
  633. fmt_time = DialogueManager.format_timestamp(dialogue_entry['timestamp'])
  634. content = '[{}] {}'.format(fmt_time, dialogue_entry['content'])
  635. return content
  636. @staticmethod
  637. def build_chat_message(role, content, chat_service_type: ChatServiceType):
  638. if chat_service_type == ChatServiceType.COZE_CHAT:
  639. if role == 'user':
  640. return cozepy.Message.build_user_question_text(content)
  641. elif role == 'assistant':
  642. return cozepy.Message.build_assistant_answer(content)
  643. else:
  644. return {'role': role, 'content': content}
# Manual smoke run: writes a CHITCHAT/GREETING state pair for one hard-coded
# staff/user id pair through DialogueStateCache.
if __name__ == '__main__':
    state_cache = DialogueStateCache()
    state_cache.set_state('1688854492669990', '7881302581935903', DialogueState.CHITCHAT, DialogueState.GREETING)