|
- #! /usr/bin/env python
- # -*- coding: utf-8 -*-
- # vim:fenc=utf-8
- import random
- from enum import Enum, auto
- from typing import Dict, List, Optional, Tuple, Any
- from datetime import datetime
- import time
- import textwrap
- import chat_service
- import prompt_templates
- from logging_service import logger
- import pymysql.cursors
- import configs
- import cozepy
- from database import MySQLManager
- from history_dialogue_service import HistoryDialogueService
- from chat_service import ChatServiceType
- from message import MessageType, Message
- from toolkit.lark_alert_for_human_intervention import LarkAlertForHumanIntervention
- from user_manager import UserManager
- from prompt_templates import *
class DummyVectorMemoryManager:
    """No-op placeholder for a long-term (vector) memory store.

    Exposes the memory-manager calls used by DialogueManager, but stores
    nothing and never reports any relevant memory.
    """

    def __init__(self, user_id):
        """Accept *user_id* for interface compatibility; it is not stored."""

    def add_to_memory(self, conversation):
        """Discard *conversation*; nothing is persisted."""
        return None

    def retrieve_relevant_memories(self, query, k=3):
        """Return a fresh empty list regardless of *query* and *k*."""
        return list()
class DialogueState(int, Enum):
    """Dialogue state of one (staff, user) conversation.

    The integer values are what DialogueStateCache writes to and reads from
    the database, so existing members must never be renumbered.
    """
    INITIALIZED = 0  # starting state; used when no stored state exists
    GREETING = 1  # greeting
    CHITCHAT = 2  # small talk
    CLARIFICATION = 3  # clarification request
    FAREWELL = 4  # farewell
    HUMAN_INTERVENTION = 5  # a human has taken over the conversation
    MESSAGE_AGGREGATING = 6  # collecting a burst of user messages before replying
class TimeContext(Enum):
    """Coarse period of the day; each value is the Chinese label fed to prompts.

    Hour buckets as assigned by DialogueManager.get_time_context:
    EARLY_MORNING 05:00-06:59, MORNING 07:00-10:59, NOON 11:00-13:59,
    AFTERNOON 14:00-17:59, EVENING 18:00-21:59, NIGHT 22:00-04:59.
    """
    EARLY_MORNING = "清晨"
    MORNING = "上午"
    NOON = "中午"
    AFTERNOON = "下午"
    EVENING = "晚上"
    NIGHT = "深夜"

    @property
    def description(self) -> str:
        """Human-readable label of the period (identical to the member value)."""
        return self.value
class DialogueStateChangeType(int, Enum):
    """Kind of mutation that DialogueManager journals for a possible rollback."""
    STATE = 0  # the (current_state, previous_state) pair changed
    INTERACTION_TIME = 1  # last_interaction_time changed
    DIALOGUE_HISTORY = 2  # one entry was appended to dialogue_history
class DialogueStateChange:
    """One undo-journal record: which kind of value changed, plus its old and new values."""

    def __init__(self, event_type: DialogueStateChangeType, old: Any, new: Any):
        # Plain attribute record; DialogueManager.rollback_state reads `event_type` and `old`.
        self.event_type, self.old, self.new = event_type, old, new
class DialogueStateCache:
    """MySQL-backed store of the dialogue state of each (staff_id, user_id) pair.

    Connection settings and the table name are read from
    ``config['storage']['agent_state']``.
    """

    def __init__(self):
        self.config = configs.get()
        agent_state_storage = self.config['storage']['agent_state']
        self.db = MySQLManager(agent_state_storage['mysql'])
        self.table = agent_state_storage['table']

    def get_state(self, staff_id: str, user_id: str) -> Tuple[DialogueState, DialogueState]:
        """Return ``(current_state, previous_state)`` for the pair.

        When no row exists, both states default to INITIALIZED and that
        default is written back through set_state.
        """
        sql = f"SELECT current_state, previous_state FROM {self.table} WHERE staff_id=%s AND user_id=%s"
        rows = self.db.select(sql, pymysql.cursors.DictCursor, (staff_id, user_id))
        if rows:
            first_row = rows[0]
            return DialogueState(first_row['current_state']), DialogueState(first_row['previous_state'])

        logger.warning(f"staff[{staff_id}], user[{user_id}]: agent state not found")
        default_state = DialogueState.INITIALIZED
        self.set_state(staff_id, user_id, default_state, default_state)
        return default_state, default_state

    def set_state(self, staff_id: str, user_id: str, state: DialogueState, previous_state: DialogueState):
        """Upsert the pair's states (stored as ints).

        Does nothing when ``debug_flags.disable_database_write`` is enabled.
        """
        write_disabled = self.config.get('debug_flags', {}).get('disable_database_write', False)
        if write_disabled:
            return
        sql = (f"INSERT INTO {self.table} (staff_id, user_id, current_state, previous_state)"
               f" VALUES (%s, %s, %s, %s) "
               f"ON DUPLICATE KEY UPDATE current_state=%s, previous_state=%s")
        params = (staff_id, user_id, state.value, previous_state.value, state.value, previous_state.value)
        affected = self.db.execute(sql, params)
        logger.debug("staff[{}], user[{}]: set state: {}, previous state: {}, rows affected: {}"
                     .format(staff_id, user_id, state, previous_state, affected))
class DialogueManager:
    """Dialogue state machine for a single (staff_id, user_id) conversation.

    What this class does:
      * recovers and persists (current_state, previous_state) through a
        DialogueStateCache;
      * aggregates bursts of user messages in MESSAGE_AGGREGATING until no new
        message has arrived for ``message_aggregation_sec`` seconds;
      * journals local mutations in ``_uncommited_state_change`` so a failed
        turn can be undone (rollback_state) or accepted (commit / persist_state);
      * builds chat configurations for an OpenAI-compatible service or Coze.
    """

    # Message types treated as images when composing prompts.
    _IMAGE_MESSAGE_TYPES = (MessageType.IMAGE_GW, MessageType.IMAGE_QW, MessageType.GIF)

    def __init__(self, staff_id: str, user_id: str, user_manager: UserManager, state_cache: DialogueStateCache):
        config = configs.get()
        self.staff_id = staff_id
        self.user_id = user_id
        self.user_manager = user_manager
        self.state_cache = state_cache
        self.current_state = DialogueState.GREETING
        self.previous_state = DialogueState.INITIALIZED
        # Currently only used for debugging; prompts are assembled from
        # history_dialogue_service instead.
        self.dialogue_history = []
        self.user_profile = self.user_manager.get_user_profile(user_id)
        self.staff_profile = self.user_manager.get_staff_profile(staff_id)
        # FIXME: interaction time and dialogue history both take part in rollback
        self.last_interaction_time = 0
        self.consecutive_clarifications = 0
        self.complex_request_counter = 0
        self.human_intervention_triggered = False
        self.vector_memory = DummyVectorMemoryManager(user_id)
        self.message_aggregation_sec = config.get('agent_behavior', {}).get('message_aggregation_sec', 5)
        self.unprocessed_messages = []
        self.history_dialogue_service = HistoryDialogueService(
            config['storage']['history_dialogue']['api_base_url']
        )
        # Local state management is complex, so mutations are journaled here and
        # can be undone by rollback_state(). Created before _recover_state() so the
        # journal exists for the whole lifetime of the object.
        self._uncommited_state_change = []
        self._recover_state()

    @staticmethod
    def get_time_context(current_hour: Optional[int] = None) -> TimeContext:
        """Map an hour of the day to a TimeContext.

        Buckets (half-open): [5, 7) EARLY_MORNING, [7, 11) MORNING,
        [11, 14) NOON, [14, 18) AFTERNOON, [18, 22) EVENING, otherwise NIGHT.

        :param current_hour: hour to classify (0-23); None means the current
            local hour.
        :return: the matching TimeContext.
        """
        # Explicit None check: hour 0 (midnight) is falsy but a valid input and
        # must not be replaced by the current hour.
        if current_hour is None:
            current_hour = datetime.now().hour
        if 5 <= current_hour < 7:
            return TimeContext.EARLY_MORNING
        if 7 <= current_hour < 11:
            return TimeContext.MORNING
        if 11 <= current_hour < 14:
            return TimeContext.NOON
        if 14 <= current_hour < 18:
            return TimeContext.AFTERNOON
        if 18 <= current_hour < 22:
            return TimeContext.EVENING
        return TimeContext.NIGHT

    def _recover_state(self):
        """Reload persisted state and dialogue history after construction.

        ``last_interaction_time`` is taken from the newest history entry, or set
        to 24 hours ago when there is no history. If the stored state is
        MESSAGE_AGGREGATING, the most recent user message is re-queued.
        """
        self.current_state, self.previous_state = self.state_cache.get_state(self.staff_id, self.user_id)
        # Restore the dialogue history from the history service.
        self.dialogue_history = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id)
        if self.dialogue_history:
            self.last_interaction_time = self.dialogue_history[-1]['timestamp']
            if self.current_state == DialogueState.MESSAGE_AGGREGATING:
                # Re-queue the last user message so aggregation can resume.
                # NOTE(review): only the newest user message is recovered; earlier
                # messages of the same burst are not re-queued.
                for entry in reversed(self.dialogue_history):
                    if entry['role'] == 'user':
                        self.unprocessed_messages.append(entry['content'])
                        break
        else:
            # No history: default to 24 hours ago.
            self.last_interaction_time = int(time.time() * 1000) - 24 * 3600 * 1000
        time_for_read = datetime.fromtimestamp(self.last_interaction_time / 1000).strftime("%Y-%m-%d %H:%M:%S")
        logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: state: {self.current_state.name}, last_interaction: {time_for_read}")

    def update_interaction_time(self, timestamp_ms: int):
        """Set ``last_interaction_time`` (ms), journaling the old value."""
        self._uncommited_state_change.append(DialogueStateChange(
            DialogueStateChangeType.INTERACTION_TIME,
            self.last_interaction_time,
            timestamp_ms
        ))
        self.last_interaction_time = timestamp_ms

    def append_dialogue_history(self, message: Dict):
        """Append *message* to the local history; a rollback pops it again."""
        self._uncommited_state_change.append(DialogueStateChange(
            DialogueStateChangeType.DIALOGUE_HISTORY,
            None,
            1
        ))
        self.dialogue_history.append(message)

    def persist_state(self):
        """Commit the journal and store the states.

        Call this only after the current turn was handled successfully. The
        database write is skipped when ``debug_flags.disable_database_write``
        is enabled.
        """
        self.commit()
        config = configs.get()
        if config.get('debug_flags', {}).get('disable_database_write', False):
            return
        self.state_cache.set_state(self.staff_id, self.user_id, self.current_state, self.previous_state)

    def rollback_state(self):
        """Undo every journaled change since the last commit, newest first.

        NOTE(review): only state, interaction time and dialogue-history appends
        are journaled. ``unprocessed_messages`` (cleared by update_state) and
        ``consecutive_clarifications`` are not restored.
        """
        logger.info(f"staff[{self.staff_id}], user[{self.user_id}]: reverse state")
        for entry in reversed(self._uncommited_state_change):
            if entry.event_type == DialogueStateChangeType.STATE:
                self.current_state, self.previous_state = entry.old
            elif entry.event_type == DialogueStateChangeType.INTERACTION_TIME:
                self.last_interaction_time = entry.old
            elif entry.event_type == DialogueStateChangeType.DIALOGUE_HISTORY:
                self.dialogue_history.pop()
            else:
                logger.error(f"unimplemented type: [{entry.event_type}]")
        self._uncommited_state_change.clear()

    def commit(self):
        """Accept all journaled changes; they can no longer be rolled back."""
        self._uncommited_state_change.clear()

    def do_state_change(self, state: DialogueState):
        """Switch to *state*, journaling the old (current, previous) pair.

        MESSAGE_AGGREGATING is transient. When leaving it, it does not become
        the previous state, so previous_state keeps the pre-aggregation state;
        the journal entry is what restores it on rollback.
        """
        state_backup = (self.current_state, self.previous_state)
        if self.current_state != DialogueState.MESSAGE_AGGREGATING:
            self.previous_state = self.current_state
        self.current_state = state
        self._uncommited_state_change.append(DialogueStateChange(
            DialogueStateChangeType.STATE,
            state_backup,
            (self.current_state, self.previous_state)
        ))

    def update_state(self, message: Message) -> Tuple[bool, Optional[str]]:
        """Update the dialogue state from a user message.

        :param message: incoming message; ``content``, ``sendTime`` (epoch ms)
            and ``type`` are read.
        :return: ``(should_reply, text)``. should_reply is True only when the
            (aggregated) user input is ready to be answered. In that case text
            is the aggregated message text; otherwise it is this message's
            text, or None for a stale aggregation trigger.
        """
        message_text = message.content
        message_ts = message.sendTime
        # Human intervention in progress: record the message but keep the state.
        if self.current_state == DialogueState.HUMAN_INTERVENTION:
            self.append_dialogue_history({
                "role": "user",
                "content": message_text,
                "timestamp": int(time.time() * 1000),
                "state": self.current_state.name
            })
            return False, message_text

        if self.current_state == DialogueState.MESSAGE_AGGREGATING:
            if message.type == MessageType.AGGREGATION_TRIGGER:
                # Timer-driven trigger while aggregating: only proceed once the
                # quiet period since the last user message has elapsed.
                if message_ts - self.last_interaction_time > self.message_aggregation_sec * 1000:
                    logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: exit aggregation waiting")
                else:
                    logger.debug(f"staff[{self.staff_id}], user[{self.user_id}]: continue aggregation waiting")
                    return False, message_text
            else:
                # Another real message: queue it, refresh the interaction time and
                # stay in MESSAGE_AGGREGATING.
                if message_text:
                    self.unprocessed_messages.append(message_text)
                    self.update_interaction_time(message_ts)
                return False, message_text
        else:
            if message.type == MessageType.AGGREGATION_TRIGGER:
                # A trigger outside aggregation is stale and must not be processed.
                logger.warning(f"staff[{self.staff_id}], user[{self.user_id}]: received {message.type} in state {self.current_state}")
                return False, None
            if self.message_aggregation_sec > 0:
                # First real message of a burst: start aggregating.
                self.do_state_change(DialogueState.MESSAGE_AGGREGATING)
                # Queue only non-empty text, as the other queueing sites do; an
                # empty/None entry would pollute (or break) the join below.
                if message_text:
                    self.unprocessed_messages.append(message_text)
                    self.update_interaction_time(message_ts)
                return False, message_text

        # Drain the aggregated messages (plus this one) and clear the queue.
        if message_text:
            self.unprocessed_messages.append(message_text)
        if self.unprocessed_messages:
            message_text = '\n'.join(self.unprocessed_messages)
            self.unprocessed_messages.clear()
        # message_text is not what finally goes to the LLM; it is only used for
        # state detection (prompts are built from history_dialogue_service).
        new_state = self._determine_state_from_message(message_text)
        # Track consecutive clarification requests.
        if new_state == DialogueState.CLARIFICATION:
            self.consecutive_clarifications += 1
            # FIXME(zhoutian): rule is too simple
            if self.consecutive_clarifications >= 10000:
                new_state = DialogueState.HUMAN_INTERVENTION
                # self._trigger_human_intervention("连续多次澄清请求")
        else:
            self.consecutive_clarifications = 0
        self.do_state_change(new_state)
        if message_text:
            self.update_interaction_time(message_ts)
        self.append_dialogue_history({
            "role": "user",
            "content": message_text,
            "timestamp": message_ts,
            "state": self.current_state.name
        })
        return True, message_text

    def _determine_state_from_message(self, message_text: Optional[str]) -> DialogueState:
        """Classify *message_text* into a dialogue state with keyword rules.

        Empty text keeps the current state. Otherwise greeting keywords are
        checked before farewell keywords, and anything else is CHITCHAT.
        """
        if not message_text:
            logger.warning(f"staff[{self.staff_id}], user[{self.user_id}]: empty message")
            # NOTE(review): when reached from MESSAGE_AGGREGATING this keeps that state.
            return self.current_state
        # Simple keyword matching.
        message_lower = message_text.lower()
        # Complex-request detection (disabled).
        # FIXME(zhoutian): rule is too simple
        # complex_request_keywords = ["帮我", "怎么办", "我需要", "麻烦你", "请帮助", "急", "紧急"]
        # if any(keyword in message_lower for keyword in complex_request_keywords):
        #     self.complex_request_counter += 1
        #
        #     # Trigger human intervention once the count reaches the threshold
        #     if self.complex_request_counter >= 1:
        #         # self._trigger_human_intervention("检测到复杂请求")
        #         return DialogueState.HUMAN_INTERVENTION
        # else:
        #     # Not a complex request: reset the counter
        #     self.complex_request_counter = 0
        greeting_keywords = ("你好", "早上好", "中午好", "晚上好", "嗨", "在吗")
        if any(keyword in message_lower for keyword in greeting_keywords):
            return DialogueState.GREETING
        farewell_keywords = ("再见", "拜拜", "晚安", "明天见", "回头见")
        if any(keyword in message_lower for keyword in farewell_keywords):
            return DialogueState.FAREWELL
        # Clarification detection (disabled).
        # clarification_keywords = ["没明白", "不明白", "没听懂", "不懂", "什么意思", "再说一遍"]
        # if any(keyword in message_lower for keyword in clarification_keywords):
        #     return DialogueState.CLARIFICATION
        # Default: small talk.
        return DialogueState.CHITCHAT

    def _trigger_human_intervention(self, reason: str) -> None:
        """Record a human-intervention event in the user profile and send an alert.

        The call sites in this class are currently commented out.
        """
        # FIXME: lost on restart
        event = {
            "timestamp": int(time.time() * 1000),
            "reason": reason,
            "dialogue_context": self.dialogue_history[-10:]
        }
        # Append the event to the human-intervention history in the user profile.
        self.user_profile.setdefault("human_intervention_history", []).append(event)
        self.user_manager.save_user_profile(self.user_id, self.user_profile)
        self._send_human_intervention_alert()

    def _send_human_intervention_alert(self) -> None:
        """Send a Lark alert with staff/user info and the last 5 history entries.

        Entries with empty content, or with a role other than user/assistant,
        are skipped.
        """
        time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # Built line by line (instead of dedenting an f-string) so that newlines
        # inside interpolated values cannot defeat the dedent.
        lines = [
            "",
            "人工介入告警",
            f"员工: {self.staff_profile.get('agent_name', '未知')}[{self.staff_id}]",
            f"用户: {self.user_profile.get('nickname', '未知')}[{self.user_id}]",
            f"时间: {time_str}",
            "最近对话:",
        ]
        role_map = {'assistant': '客服', 'user': '用户'}
        for dialogue in self.dialogue_history[-5:]:
            content = dialogue.get('content')
            role_label = role_map.get(dialogue.get('role'))
            if not content or role_label is None:
                continue
            lines.append(f"[{role_label}]{content}")
        alert_message = "\n".join(lines)
        LarkAlertForHumanIntervention().send_lark_alert_for_human_intervention(alert_message)

    def resume_from_human_intervention(self) -> None:
        """Leave HUMAN_INTERVENTION for CHITCHAT, reset counters and log a system entry.

        Does nothing when the dialogue is not in HUMAN_INTERVENTION.
        """
        if self.current_state != DialogueState.HUMAN_INTERVENTION:
            return
        self.do_state_change(DialogueState.CHITCHAT)
        self.consecutive_clarifications = 0
        self.complex_request_counter = 0
        # Record the recovery event.
        self.append_dialogue_history({
            "role": "system",
            "content": "已从人工介入状态恢复到自动对话",
            "timestamp": int(time.time() * 1000),
            "state": self.current_state.name
        })

    def generate_response(self, llm_response: str) -> Optional[str]:
        """Post-process an LLM response and update dialogue state and history.

        Every LLM response must go through this method.

        * If the response contains the ``<人工介入>`` marker, switch to
          HUMAN_INTERVENTION, send an alert and return None.
        * If already in HUMAN_INTERVENTION, return None (no reply).
        * Otherwise record the response as an assistant turn, refresh the
          interaction time and return it unchanged.

        :param llm_response: raw text returned by the LLM.
        :return: the reply to send, or None when no reply should be sent.
        """
        if '<人工介入>' in llm_response:
            logger.warning(f'staff[{self.staff_id}], user[{self.user_id}]: human intervention triggered')
            self.do_state_change(DialogueState.HUMAN_INTERVENTION)
            self._send_human_intervention_alert()
            return None
        if self.current_state == DialogueState.HUMAN_INTERVENTION:
            return None
        message_ts = int(time.time() * 1000)
        self.append_dialogue_history({
            "role": "assistant",
            "content": llm_response,
            "timestamp": message_ts,
            "state": self.current_state.name
        })
        self.update_interaction_time(message_ts)
        return llm_response

    def _get_hours_since_last_interaction(self, precision: int = -1):
        """Hours since ``last_interaction_time``, rounded to *precision* decimals if >= 0."""
        time_diff = (time.time() * 1000) - self.last_interaction_time
        hours_passed = time_diff / 1000 / 3600
        if precision >= 0:
            return round(hours_passed, precision)
        return hours_passed

    def should_initiate_conversation(self) -> bool:
        """Decide whether the bot should proactively start a conversation.

        Returns False in HUMAN_INTERVENTION or when the user's
        ``interaction_frequency`` is 'stopped'. Otherwise requires enough hours
        since the last interaction (low 48, medium 24, high/unknown 12) and a
        suitable time of day.
        """
        if self.current_state == DialogueState.HUMAN_INTERVENTION:
            return False
        hours_passed = self._get_hours_since_last_interaction()
        time_context = self.get_time_context()
        interaction_frequency = self.user_profile.get("interaction_frequency", "medium")
        if interaction_frequency == 'stopped':
            return False
        # Idle-time thresholds (hours) per interaction-frequency preference.
        thresholds = {
            "low": 48,
            "medium": 24,
            "high": 12
        }
        threshold = thresholds.get(interaction_frequency, 12)
        if hours_passed < threshold:
            return False
        return self.is_time_suitable_for_active_conversation(time_context)

    @staticmethod
    def is_time_suitable_for_active_conversation(time_context=None) -> bool:
        """True for MORNING, NOON and AFTERNOON; None means the current time context."""
        if time_context is None:
            time_context = DialogueManager.get_time_context()
        return time_context in (TimeContext.MORNING, TimeContext.NOON, TimeContext.AFTERNOON)

    def is_in_human_intervention(self) -> bool:
        """Return whether the dialogue is in HUMAN_INTERVENTION."""
        return self.current_state == DialogueState.HUMAN_INTERVENTION

    def get_prompt_context(self, user_message) -> Dict:
        """Assemble the variables used to render prompts.

        Refreshes the user and staff profiles from the user manager. The
        result contains time information, state names and interaction flags,
        followed by both profiles flattened in (staff keys override user keys,
        and both override the computed keys), plus ``long_term_memory``.

        :param user_message: current user message; empty/None marks a
            proactive greeting (``if_active_greeting``).
        """
        time_context = self.get_time_context()
        # Refresh the user profile.
        self.user_profile = self.user_manager.get_user_profile(self.user_id)
        # Refresh the staff profile (not strictly necessary).
        self.staff_profile = self.user_manager.get_staff_profile(self.staff_id)
        current_datetime = datetime.now()
        context = {
            "user_profile": self.user_profile,
            "current_state": self.current_state.name,
            "previous_state": self.previous_state.name,
            "current_time_period": time_context.description,
            "current_hour": current_datetime.hour,
            "current_time": current_datetime.strftime("%H:%M:%S"),
            "current_date": current_datetime.strftime("%Y-%m-%d"),
            "last_interaction_interval": self._get_hours_since_last_interaction(2),
            "if_first_interaction": self.previous_state == DialogueState.INITIALIZED,
            "if_active_greeting": not user_message,
            **self.user_profile,
            **self.staff_profile
        }
        relevant_memories = self.vector_memory.retrieve_relevant_memories(user_message)
        context["long_term_memory"] = {
            "relevant_conversations": relevant_memories
        }
        return context

    @staticmethod
    def _select_prompt(state):
        """Return the system prompt template for *state*.

        Only GREETING, CHITCHAT and FAREWELL are mapped; any other state
        raises KeyError.
        """
        state_to_prompt_map = {
            DialogueState.GREETING: GENERAL_GREETING_PROMPT,
            DialogueState.CHITCHAT: CHITCHAT_PROMPT_COZE,
            DialogueState.FAREWELL: GENERAL_GREETING_PROMPT
        }
        return state_to_prompt_map[state]

    @staticmethod
    def _select_coze_bot(state, dialogue: List[Dict], multimodal=False):
        """Return the Coze bot id for *state* (separate bots when *multimodal*).

        *dialogue* is currently unused. Only GREETING, CHITCHAT and FAREWELL are
        mapped; any other state raises KeyError.
        """
        if multimodal:
            state_to_bot_map = {
                DialogueState.GREETING: '7496772218198900770',
                DialogueState.CHITCHAT: '7495692989504438308',
                DialogueState.FAREWELL: '7491300566573301770',
            }
        else:
            state_to_bot_map = {
                DialogueState.GREETING: '7486112546798780425',
                DialogueState.CHITCHAT: '7491300566573301770',
                DialogueState.FAREWELL: '7491300566573301770',
            }
        return state_to_bot_map[state]

    @staticmethod
    def need_multimodal_model(dialogue: List[Dict], max_message_to_use: int = 10):
        """Return True if any of the last *max_message_to_use* entries is an image.

        A non-positive *max_message_to_use* inspects nothing. Otherwise
        ``dialogue[-0:]`` would silently scan the whole history.
        """
        # Simple heuristic for now.
        if max_message_to_use <= 0:
            return False
        return any(entry.get('type') in DialogueManager._IMAGE_MESSAGE_TYPES
                   for entry in dialogue[-max_message_to_use:])

    def _create_system_message(self, prompt_context):
        """Render the current state's prompt template into a system message."""
        prompt_template = self._select_prompt(self.current_state)
        prompt = prompt_template.format(**prompt_context)
        return {'role': 'system', 'content': prompt}

    @staticmethod
    def compose_chat_messages_openai_compatible(dialogue_history, current_time, multimodal=False):
        """Convert history entries into OpenAI-style chat messages.

        Text entries become ``"[<time>] <content>"``. Image entries are sent as
        ``image_url`` parts when *multimodal* is True, and are replaced by an
        ``[图片]`` placeholder otherwise. A trailing assistant message
        ``"[<current_time>]"`` is appended to anchor the reply's time.
        """
        messages = []
        for entry in dialogue_history:
            role = entry['role']
            msg_type = entry.get('type', MessageType.TEXT)
            fmt_time = DialogueManager.format_timestamp(entry['timestamp'])
            if msg_type not in DialogueManager._IMAGE_MESSAGE_TYPES:
                messages.append({
                    "role": role,
                    "content": '[{}] {}'.format(fmt_time, entry["content"])
                })
            elif multimodal:
                messages.append({
                    "role": role,
                    "content": [
                        {"type": "image_url", "image_url": {"url": entry["content"]}}
                    ]
                })
            else:
                logger.warning("Image in non-multimodal mode")
                messages.append({
                    "role": role,
                    "content": "[{}] {}".format(fmt_time, '[图片]')
                })
        # Prefix message that pins the reply to the current time.
        msg_prefix = '[{}]'.format(current_time)
        messages.append({'role': 'assistant', 'content': msg_prefix})
        return messages

    @staticmethod
    def _flush_coze_run(messages, role, texts, objects, staff_id, user_id):
        """Append one Coze message for a run of consecutive entries of *role*.

        A run containing images becomes a single object-string message:
        all images, then one text part with the joined *texts*. A text-only run
        becomes one text message. An empty run appends nothing.
        """
        if objects:
            # Images are only collected for user turns (see the caller). One
            # object_string array may hold many images but only one text part.
            # FIXME: support other message types
            parts = [cozepy.MessageObjectString.build_image(file_url=obj['content']) for obj in objects]
            parts.append(cozepy.MessageObjectString.build_text('\n'.join(texts)))
            messages.append(cozepy.Message.build_user_question_objects(parts))
        elif texts:
            message = DialogueManager.build_chat_message(role, '\n'.join(texts), ChatServiceType.COZE_CHAT)
            if message is None:
                # build_chat_message has no Coze builder for this role; do not
                # put a None into the message list.
                logger.warning("staff[{}], user[{}], role[{}]: unsupported role for coze message, dropped".format(
                    staff_id, user_id, role
                ))
            else:
                messages.append(message)

    @staticmethod
    def compose_chat_messages_coze(dialogue_history, current_time, staff_id, user_id):
        """Convert history entries into Coze messages.

        Consecutive entries of the same role are merged into one message,
        because Coze may otherwise drop messages. *dialogue_history* is not
        modified.

        :param current_time: ``'%Y-%m-%d %H:%M:%S'`` string used as the
            timestamp of the trailing blank user turn.
        """
        messages = []
        # If the first message after the system prompt is not from the user,
        # Coze swallows the assistant message; prepend a near-empty user turn.
        if dialogue_history and dialogue_history[0]['role'] != 'user':
            fmt_time = DialogueManager.format_timestamp(dialogue_history[0]['timestamp'])
            messages.append(cozepy.Message.build_user_question_text(f'[{fmt_time}] '))
        # Coze requires the last message to be from the user and may swallow
        # consecutive user messages, so a blank user turn is appended (it can be
        # merged into a preceding user run). Work on a copy: the caller's list
        # must not grow.
        history = list(dialogue_history)
        history.append({
            'role': 'user',
            'content': ' ',
            'timestamp': int(datetime.strptime(current_time, '%Y-%m-%d %H:%M:%S').timestamp() * 1000),
        })
        texts_to_aggr = []
        objects_to_aggr = []
        last_message_role = None
        for entry in history:
            if not entry['content']:
                logger.warning("staff[{}], user[{}], role[{}]: empty content in dialogue history".format(
                    staff_id, user_id, entry['role']
                ))
                continue
            role = entry['role']
            if role != last_message_role:
                DialogueManager._flush_coze_run(
                    messages, last_message_role, texts_to_aggr, objects_to_aggr, staff_id, user_id)
                texts_to_aggr = []
                objects_to_aggr = []
                last_message_role = role
            if entry.get('type', MessageType.TEXT) in DialogueManager._IMAGE_MESSAGE_TYPES:
                if role == 'user':
                    objects_to_aggr.append(entry)
                else:
                    logger.warning("staff[{}], user[{}]: unsupported message type [{}] in assistant role".format(
                        staff_id, user_id, entry['type']
                    ))
            else:
                texts_to_aggr.append(DialogueManager.format_dialogue_content(entry))
        # Flush the final run.
        DialogueManager._flush_coze_run(
            messages, last_message_role, texts_to_aggr, objects_to_aggr, staff_id, user_id)
        return messages

    def build_active_greeting_config(self, user_tags: List[str]):
        """Build the chat config for a proactive (bot-initiated) greeting.

        A greeting prompt is picked at random from a default pool, unless a tag
        in *user_tags* maps to an experiment prompt (the last matching tag
        wins). For GREETING_WITH_AVATAR_STORY and GREETING_WITH_INTEREST_QUERY
        the user's avatar URL is attached as an image and the multimodal model
        is enabled.

        :param user_tags: the user's tags; None is treated as no tags.
        :return: dict with ``user_id``, ``model_name``, ``messages`` and,
            when an image is attached, ``use_multimodal_model``.
        """
        # FIXME: poor abstraction; short-term support for manually configured experiments.
        # Product/operations require the GPT-4o model here.
        chat_config = {'user_id': self.user_id, 'model_name': chat_service.OPENAI_MODEL_GPT_4o}
        prompt_context = self.get_prompt_context(None)
        system_message = {'role': 'system', 'content': 'You are a helpful AI assistant.'}
        # TODO: pick a prompt randomly, by strategy, or from user tags
        # TODO: distinguish users with past interactions / similar content already sent
        greeting_prompts = [
            prompt_templates.GREETING_WITH_IMAGE_GAME,
            prompt_templates.GREETING_WITH_NAME_POETRY,
            prompt_templates.GREETING_WITH_AVATAR_STORY
        ]
        # Default: random choice.
        selected_prompt = random.choice(greeting_prompts)
        # Experiment configuration: tag -> prompt.
        tag_to_greeting_map = {
            '04W4-AA-1': prompt_templates.GREETING_WITH_NAME_POETRY,
            '04W4-AA-2': prompt_templates.GREETING_WITH_AVATAR_STORY,
            '04W4-AA-3': prompt_templates.GREETING_WITH_INTEREST_QUERY,
            '04W4-AA-4': prompt_templates.GREETING_WITH_CALENDAR,
        }
        for tag in user_tags or []:
            if tag in tag_to_greeting_map:
                selected_prompt = tag_to_greeting_map[tag]
        prompt = selected_prompt.format(**prompt_context)
        user_message = {'role': 'user', 'content': prompt}
        messages = [system_message, user_message]
        if selected_prompt in (
            prompt_templates.GREETING_WITH_AVATAR_STORY,
            prompt_templates.GREETING_WITH_INTEREST_QUERY,
        ):
            messages.append({
                "role": 'user',
                "content": [
                    {"type": "image_url", "image_url": {"url": self.user_profile['avatar']}}
                ]
            })
            chat_config['use_multimodal_model'] = True
        chat_config['messages'] = messages
        return chat_config

    def build_chat_configuration(
            self,
            user_message: Optional[str] = None,
            chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE,
            overwrite_context: Optional[Dict] = None
    ) -> Dict:
        """Build the chat-service configuration for the next LLM call.

        :param user_message: current user message; None for a proactive turn.
        :param chat_service_type: OPENAI_COMPATIBLE (system prompt plus history
            messages) or COZE_CHAT (history messages plus custom variables and
            a bot id).
        :param overwrite_context: values merged over the prompt context. A
            ``current_time`` entry (``'%Y-%m-%d %H:%M:%S'`` string) also
            overrides the time used when composing messages.
        :return: dict with ``user_id``, ``use_multimodal_model``, ``messages``
            and, for Coze, ``custom_variables`` and ``bot_id``.
        """
        dialogue_history = self.history_dialogue_service.get_dialogue_history(self.staff_id, self.user_id)
        logger.debug("staff[{}], user[{}], recent dialogue_history: {}".format(
            self.staff_id, self.user_id, dialogue_history[-20:]
        ))
        messages = []
        config = {
            'user_id': self.user_id
        }
        prompt_context = self.get_prompt_context(user_message)
        if overwrite_context:
            prompt_context.update(overwrite_context)
        # FIXME(zhoutian): time in string type
        current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        if overwrite_context and 'current_time' in overwrite_context:
            current_time = overwrite_context.get('current_time')
        need_multimodal = self.need_multimodal_model(dialogue_history)
        config['use_multimodal_model'] = need_multimodal
        if chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
            messages.append(self._create_system_message(prompt_context))
            messages.extend(self.compose_chat_messages_openai_compatible(dialogue_history, current_time, need_multimodal))
        elif chat_service_type == ChatServiceType.COZE_CHAT:
            # Coze supports at most 100 messages, and a system message still has to be attached.
            dialogue_history = dialogue_history[-95:]
            messages = self.compose_chat_messages_coze(dialogue_history, current_time, self.staff_id, self.user_id)
            # All custom variables are stringified; the user_profile entry is dropped.
            config['custom_variables'] = {k: str(v) for k, v in prompt_context.items() if k != 'user_profile'}
            config['bot_id'] = self._select_coze_bot(self.current_state, dialogue_history, need_multimodal)
        # FIXME(zhoutian): temporary alert
        if user_message and not messages:
            logger.error(f"staff[{self.staff_id}], user[{self.user_id}]: inconsistency in messages")
        config['messages'] = messages
        return config

    @staticmethod
    def format_timestamp(timestamp_ms):
        """Format an epoch-milliseconds timestamp as local ``%Y-%m-%d %H:%M:%S``."""
        return datetime.fromtimestamp(timestamp_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")

    @staticmethod
    def format_dialogue_content(dialogue_entry):
        """Render a history entry as ``"[<time>] <content>"``."""
        fmt_time = DialogueManager.format_timestamp(dialogue_entry['timestamp'])
        return '[{}] {}'.format(fmt_time, dialogue_entry['content'])

    @staticmethod
    def build_chat_message(role, content, chat_service_type: ChatServiceType):
        """Build a chat message for *chat_service_type*.

        For COZE_CHAT only 'user' and 'assistant' roles are supported; any
        other role returns None. Other service types get a plain
        ``{'role', 'content'}`` dict.
        """
        if chat_service_type == ChatServiceType.COZE_CHAT:
            if role == 'user':
                return cozepy.Message.build_user_question_text(content)
            elif role == 'assistant':
                return cozepy.Message.build_assistant_answer(content)
            return None
        else:
            return {'role': role, 'content': content}
if __name__ == '__main__':
    # Manual smoke test: upsert a fixed (staff, user) pair as CHITCHAT with
    # GREETING as its previous state.
    staff_id, user_id = '1688854492669990', '7881302581935903'
    DialogueStateCache().set_state(staff_id, user_id, DialogueState.CHITCHAT, DialogueState.GREETING)
|