agent_service.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import re
  5. import sys
  6. import time
  7. import random
  8. from typing import Dict, List, Tuple, Any, Optional
  9. import logging
  10. from datetime import datetime, timedelta
  11. import traceback
  12. import apscheduler.triggers.cron
  13. from apscheduler.schedulers.background import BackgroundScheduler
  14. import chat_service
  15. import configs
  16. import logging_service
  17. from configs import apollo_config
  18. from logging_service import logger
  19. from chat_service import CozeChat, ChatServiceType
  20. from dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
  21. from response_type_detector import ResponseTypeDetector
  22. from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager
  23. from openai import OpenAI
  24. from message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
  25. from user_profile_extractor import UserProfileExtractor
  26. import threading
  27. from message import MessageType, Message, MessageChannel
  28. class AgentService:
  29. def __init__(
  30. self,
  31. receive_backend: MessageQueueBackend,
  32. send_backend: MessageQueueBackend,
  33. human_backend: MessageQueueBackend,
  34. user_manager: UserManager,
  35. user_relation_manager: UserRelationManager,
  36. chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
  37. ):
  38. self.receive_queue = receive_backend
  39. self.send_queue = send_backend
  40. self.human_queue = human_backend
  41. # 核心服务模块
  42. self.agent_state_cache = DialogueStateCache()
  43. self.user_manager = user_manager
  44. self.user_relation_manager = user_relation_manager
  45. self.user_profile_extractor = UserProfileExtractor()
  46. self.response_type_detector = ResponseTypeDetector()
  47. self.agent_registry: Dict[str, DialogueManager] = {}
  48. chat_config = configs.get()['chat_api']['openai_compatible']
  49. self.text_model_name = chat_config['text_model']
  50. self.multimodal_model_name = chat_config['multimodal_model']
  51. self.text_model_client = chat_service.OpenAICompatible.create_client(self.text_model_name)
  52. self.multimodal_model_client = chat_service.OpenAICompatible.create_client(self.multimodal_model_name)
  53. coze_config = configs.get()['chat_api']['coze']
  54. coze_oauth_app = CozeChat.get_oauth_app(
  55. coze_config['oauth_client_id'], coze_config['private_key_path'], str(coze_config['public_key_id']),
  56. account_id=coze_config.get('account_id', None)
  57. )
  58. self.coze_client = CozeChat(
  59. base_url=chat_service.COZE_CN_BASE_URL,
  60. auth_app=coze_oauth_app
  61. )
  62. self.chat_service_type = chat_service_type
  63. # 定时任务调度器
  64. self.scheduler = BackgroundScheduler()
  65. self.scheduler.start()
  66. self.limit_initiative_conversation_rate = True
  67. def setup_initiative_conversations(self, schedule_params: Optional[Dict] = None):
  68. if not schedule_params:
  69. schedule_params = {'hour': '8,16,20'}
  70. self.scheduler.add_job(
  71. self._check_initiative_conversations,
  72. apscheduler.triggers.cron.CronTrigger(**schedule_params)
  73. )
  74. def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
  75. """获取Agent实例"""
  76. agent_key = 'agent_{}_{}'.format(staff_id, user_id)
  77. if agent_key not in self.agent_registry:
  78. self.agent_registry[agent_key] = DialogueManager(
  79. staff_id, user_id, self.user_manager, self.agent_state_cache)
  80. return self.agent_registry[agent_key]
  81. def process_messages(self):
  82. """持续处理接收队列消息"""
  83. while True:
  84. message = self.receive_queue.consume()
  85. if message:
  86. try:
  87. self.process_single_message(message)
  88. self.receive_queue.ack(message)
  89. except Exception as e:
  90. logger.error("Error processing message: {}".format(e))
  91. traceback.print_exc()
  92. time.sleep(1)
  93. def _update_user_profile(self, user_id, user_profile, recent_dialogue: List[Dict]):
  94. profile_to_update = self.user_profile_extractor.extract_profile_info(user_profile, recent_dialogue)
  95. if not profile_to_update:
  96. logger.debug("user_id: {}, no profile info extracted".format(user_id))
  97. return
  98. logger.warning("update user profile: {}".format(profile_to_update))
  99. merged_profile = self.user_profile_extractor.merge_profile_info(user_profile, profile_to_update)
  100. self.user_manager.save_user_profile(user_id, merged_profile)
  101. return merged_profile
  102. def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
  103. logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
  104. message_ts = int((time.time() + delay_sec) * 1000)
  105. message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
  106. # 系统消息使用特定的msgId,无实际意义
  107. message.msgId = -MessageType.AGGREGATION_TRIGGER.value
  108. self.scheduler.add_job(lambda: self.receive_queue.produce(message),
  109. 'date',
  110. run_date=datetime.now() + timedelta(seconds=delay_sec))
  111. def process_single_message(self, message: Message):
  112. user_id = message.sender
  113. staff_id = message.receiver
  114. # 获取用户信息和Agent实例
  115. user_profile = self.user_manager.get_user_profile(user_id)
  116. agent = self._get_agent_instance(staff_id, user_id)
  117. # 更新对话状态
  118. logger.debug("process message: {}".format(message))
  119. need_response, message_text = agent.update_state(message)
  120. logger.debug("user: {}, next state: {}".format(user_id, agent.current_state))
  121. # 根据状态路由消息
  122. try:
  123. if agent.is_in_human_intervention():
  124. self._route_to_human_intervention(user_id, message)
  125. elif agent.current_state == DialogueState.MESSAGE_AGGREGATING:
  126. if message.type != MessageType.AGGREGATION_TRIGGER:
  127. # 产生一个触发器,但是不能由触发器递归产生
  128. logger.debug("user: {}, waiting next message for aggregation".format(user_id))
  129. self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
  130. elif need_response:
  131. # 先更新用户画像再处理回复
  132. self._update_user_profile(user_id, user_profile, agent.dialogue_history[-10:])
  133. resp = self._get_chat_response(user_id, agent, message_text)
  134. if resp:
  135. recent_dialogue = agent.dialogue_history[-10:]
  136. if len(recent_dialogue) < 2:
  137. message_type = MessageType.TEXT
  138. else:
  139. message_type = self.response_type_detector.detect_type(recent_dialogue[:-1], recent_dialogue[-1])
  140. self._send_response(staff_id, user_id, resp, message_type)
  141. else:
  142. logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
  143. # 当前消息处理成功,持久化agent状态
  144. agent.persist_state()
  145. except Exception as e:
  146. agent.rollback_state()
  147. raise e
  148. def _send_response(self, staff_id, user_id, response, message_type: MessageType):
  149. logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
  150. current_ts = int(time.time() * 1000)
  151. user_tags = self.user_relation_manager.get_user_tags(user_id)
  152. white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags"))
  153. hit_white_list_tags = len(set(user_tags).intersection(white_list_tags)) > 0
  154. # FIXME(zhoutian)
  155. # 测试期间临时逻辑,只发送特定的账号或特定用户
  156. staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs"))
  157. if not (staff_id in staff_white_lists or hit_white_list_tags):
  158. logger.warning(f"staff[{staff_id}] user[{user_id}]: skip reply")
  159. return None
  160. self.send_queue.produce(
  161. Message.build(message_type, MessageChannel.CORP_WECHAT,
  162. staff_id, user_id, response, current_ts)
  163. )
  164. def _route_to_human_intervention(self, user_id: str, origin_message: Message):
  165. """路由到人工干预"""
  166. self.human_queue.produce(Message.build(
  167. MessageType.TEXT,
  168. origin_message.channel,
  169. origin_message.sender,
  170. origin_message.receiver,
  171. "用户对话需人工介入,用户名:{}".format(user_id),
  172. int(time.time() * 1000)
  173. ))
  174. def _check_initiative_conversations(self):
  175. """定时检查主动发起对话"""
  176. for staff_user in self.user_relation_manager.list_staff_users():
  177. staff_id = staff_user['staff_id']
  178. user_id = staff_user['user_id']
  179. agent = self._get_agent_instance(staff_id, user_id)
  180. should_initiate = agent.should_initiate_conversation()
  181. user_tags = self.user_relation_manager.get_user_tags(user_id)
  182. white_list_tags = apollo_config.get_json_value('agent_initiate_whitelist_tags')
  183. if not set(user_tags).intersection(white_list_tags):
  184. should_initiate = False
  185. if should_initiate:
  186. logger.warning("user: {}, initiate conversation".format(user_id))
  187. resp = self._get_chat_response(user_id, agent, None)
  188. if resp:
  189. self._send_response(staff_id, user_id, resp, MessageType.TEXT)
  190. if self.limit_initiative_conversation_rate:
  191. time.sleep(random.randint(10,20))
  192. else:
  193. logger.debug("user: {}, do not initiate conversation".format(user_id))
  194. def _get_chat_response(self, user_id: str, agent: DialogueManager,
  195. user_message: Optional[str]):
  196. """处理LLM响应"""
  197. chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
  198. logger.debug(chat_config)
  199. chat_response = self._call_chat_api(chat_config)
  200. chat_response = self.sanitize_response(chat_response)
  201. if response := agent.generate_response(chat_response):
  202. return response
  203. else:
  204. logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
  205. return None
  206. def _call_chat_api(self, chat_config: Dict) -> str:
  207. if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
  208. return 'LLM模拟回复 {}'.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
  209. if self.chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
  210. if chat_config['use_multimodal_model']:
  211. chat_completion = self.multimodal_model_client.chat.completions.create(
  212. messages=chat_config['messages'],
  213. model=self.multimodal_model_name,
  214. )
  215. else:
  216. chat_completion = self.text_model_client.chat.completions.create(
  217. messages=chat_config['messages'],
  218. model=self.text_model_client,
  219. )
  220. response = chat_completion.choices[0].message.content
  221. elif self.chat_service_type == ChatServiceType.COZE_CHAT:
  222. bot_user_id = 'qywx_{}'.format(chat_config['user_id'])
  223. response = self.coze_client.create(
  224. chat_config['bot_id'], bot_user_id, chat_config['messages'],
  225. chat_config['custom_variables']
  226. )
  227. else:
  228. raise Exception('Unsupported chat service type: {}'.format(self.chat_service_type))
  229. return response
  230. @staticmethod
  231. def sanitize_response(response: str):
  232. pattern = r'\[?\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]?'
  233. response = re.sub(pattern, '', response)
  234. response = response.strip()
  235. return response
  236. if __name__ == "__main__":
  237. config = configs.get()
  238. logging_service.setup_root_logger()
  239. logger.warning("current env: {}".format(configs.get_env()))
  240. scheduler_logger = logging.getLogger('apscheduler')
  241. scheduler_logger.setLevel(logging.WARNING)
  242. use_aliyun_mq = config['debug_flags']['use_aliyun_mq']
  243. # 初始化不同队列的后端
  244. if use_aliyun_mq:
  245. receive_queue = AliyunRocketMQQueueBackend(
  246. config['mq']['endpoints'],
  247. config['mq']['instance_id'],
  248. config['mq']['receive_topic'],
  249. has_consumer=True, has_producer=True,
  250. group_id=config['mq']['receive_group']
  251. )
  252. send_queue = AliyunRocketMQQueueBackend(
  253. config['mq']['endpoints'],
  254. config['mq']['instance_id'],
  255. config['mq']['send_topic'],
  256. has_consumer=False, has_producer=True
  257. )
  258. else:
  259. receive_queue = MemoryQueueBackend()
  260. send_queue = MemoryQueueBackend()
  261. human_queue = MemoryQueueBackend()
  262. # 初始化用户管理服务
  263. # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
  264. user_db_config = config['storage']['user']
  265. staff_db_config = config['storage']['staff']
  266. if config['debug_flags'].get('use_local_user_storage', False):
  267. user_manager = LocalUserManager()
  268. else:
  269. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
  270. wecom_db_config = config['storage']['user_relation']
  271. user_relation_manager = MySQLUserRelationManager(
  272. user_db_config['mysql'], wecom_db_config['mysql'],
  273. config['storage']['staff']['table'],
  274. user_db_config['table'],
  275. wecom_db_config['table']['staff'],
  276. wecom_db_config['table']['relation'],
  277. wecom_db_config['table']['user']
  278. )
  279. # 创建Agent服务
  280. service = AgentService(
  281. receive_backend=receive_queue,
  282. send_backend=send_queue,
  283. human_backend=human_queue,
  284. user_manager=user_manager,
  285. user_relation_manager=user_relation_manager,
  286. chat_service_type=ChatServiceType.COZE_CHAT
  287. )
  288. # 只有企微场景需要主动发起
  289. if not config['debug_flags'].get('disable_active_conversation', False):
  290. schedule_param = config['agent_behavior'].get('schedule_param', None)
  291. service.setup_initiative_conversations(schedule_param)
  292. process_thread = threading.Thread(target=service.process_messages)
  293. process_thread.start()
  294. if not config['debug_flags'].get('console_input', False):
  295. process_thread.join()
  296. sys.exit(0)
  297. message_id = 0
  298. while True:
  299. print("Input next message: ")
  300. text = sys.stdin.readline().strip()
  301. if not text:
  302. continue
  303. message_id += 1
  304. sender = '7881301263964433'
  305. receiver = '1688854492669990'
  306. if text == MessageType.AGGREGATION_TRIGGER.name:
  307. message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.CORP_WECHAT,
  308. sender, receiver, None, int(time.time() * 1000))
  309. else:
  310. message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
  311. sender,receiver, text, int(time.time() * 1000)
  312. )
  313. message.msgId = message_id
  314. receive_queue.produce(message)
  315. time.sleep(0.1)
  316. process_thread.join()