agent_service.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import re
  5. import signal
  6. import sys
  7. import time
  8. import random
  9. from typing import Dict, List, Tuple, Any, Optional
  10. import logging
  11. from datetime import datetime, timedelta
  12. import traceback
  13. import apscheduler.triggers.cron
  14. from apscheduler.schedulers.background import BackgroundScheduler
  15. import chat_service
  16. import configs
  17. import logging_service
  18. from configs import apollo_config
  19. from logging_service import logger
  20. from chat_service import CozeChat, ChatServiceType
  21. from dialogue_manager import DialogueManager, DialogueState, DialogueStateCache
  22. from rate_limiter import MessageSenderRateLimiter
  23. from response_type_detector import ResponseTypeDetector
  24. from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager, \
  25. LocalUserRelationManager
  26. from openai import OpenAI
  27. from message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
  28. from user_profile_extractor import UserProfileExtractor
  29. import threading
  30. from message import MessageType, Message, MessageChannel
  31. class AgentService:
  32. def __init__(
  33. self,
  34. receive_backend: MessageQueueBackend,
  35. send_backend: MessageQueueBackend,
  36. human_backend: MessageQueueBackend,
  37. user_manager: UserManager,
  38. user_relation_manager: UserRelationManager,
  39. chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
  40. ):
  41. self.receive_queue = receive_backend
  42. self.send_queue = send_backend
  43. self.human_queue = human_backend
  44. # 核心服务模块
  45. self.agent_state_cache = DialogueStateCache()
  46. self.user_manager = user_manager
  47. self.user_relation_manager = user_relation_manager
  48. self.user_profile_extractor = UserProfileExtractor()
  49. self.response_type_detector = ResponseTypeDetector()
  50. self.agent_registry: Dict[str, DialogueManager] = {}
  51. self.config = configs.get()
  52. chat_config = self.config['chat_api']['openai_compatible']
  53. self.text_model_name = chat_config['text_model']
  54. self.multimodal_model_name = chat_config['multimodal_model']
  55. self.text_model_client = chat_service.OpenAICompatible.create_client(self.text_model_name)
  56. self.multimodal_model_client = chat_service.OpenAICompatible.create_client(self.multimodal_model_name)
  57. coze_config = configs.get()['chat_api']['coze']
  58. coze_oauth_app = CozeChat.get_oauth_app(
  59. coze_config['oauth_client_id'], coze_config['private_key_path'], str(coze_config['public_key_id']),
  60. account_id=coze_config.get('account_id', None)
  61. )
  62. self.coze_client = CozeChat(
  63. base_url=chat_service.COZE_CN_BASE_URL,
  64. auth_app=coze_oauth_app
  65. )
  66. self.chat_service_type = chat_service_type
  67. # 定时任务调度器
  68. self.scheduler = None
  69. self.scheduler_mode = self.config.get('system', {}).get('scheduler_mode', 'local')
  70. self.scheduler_queue = None
  71. self.msg_scheduler_thread = None
  72. self.running = False
  73. self.process_thread = None
  74. self._sigint_cnt = 0
  75. self.send_rate_limiter = MessageSenderRateLimiter()
  76. def setup_initiative_conversations(self, schedule_params: Optional[Dict] = None):
  77. if not schedule_params:
  78. schedule_params = {'hour': '8,16,20'}
  79. self.scheduler.add_job(
  80. self._check_initiative_conversations,
  81. apscheduler.triggers.cron.CronTrigger(**schedule_params)
  82. )
  83. def setup_scheduler(self):
  84. self.scheduler = BackgroundScheduler()
  85. if self.scheduler_mode == 'mq':
  86. logging.info("setup event message scheduler with MQ")
  87. mq_conf = self.config['mq']
  88. topic = mq_conf['scheduler_topic']
  89. self.scheduler_queue = AliyunRocketMQQueueBackend(
  90. mq_conf['endpoints'],
  91. mq_conf['instance_id'],
  92. topic,
  93. has_consumer=True, has_producer=True,
  94. group_id=mq_conf['scheduler_group'],
  95. topic_type='DELAY'
  96. )
  97. self.msg_scheduler_thread = threading.Thread(target=self.process_scheduler_events)
  98. self.msg_scheduler_thread.start()
  99. self.scheduler.start()
  100. def process_scheduler_events(self):
  101. while self.running:
  102. msg = self.scheduler_queue.consume()
  103. if msg:
  104. try:
  105. self.process_scheduler_event(msg)
  106. self.scheduler_queue.ack(msg)
  107. except Exception as e:
  108. logger.error("Error processing scheduler event: {}".format(e))
  109. time.sleep(1)
  110. logger.info("Scheduler event processing thread exit")
  111. def process_scheduler_event(self, msg: Message):
  112. if msg.type == MessageType.AGGREGATION_TRIGGER:
  113. # 延迟触发的消息,需放入接收队列以驱动Agent运转
  114. self.receive_queue.produce(msg)
  115. else:
  116. logger.warning(f"Unknown message type: {msg.type}")
  117. def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
  118. """获取Agent实例"""
  119. agent_key = 'agent_{}_{}'.format(staff_id, user_id)
  120. if agent_key not in self.agent_registry:
  121. self.agent_registry[agent_key] = DialogueManager(
  122. staff_id, user_id, self.user_manager, self.agent_state_cache)
  123. return self.agent_registry[agent_key]
  124. def process_messages(self):
  125. """持续处理接收队列消息"""
  126. while self.running:
  127. message = self.receive_queue.consume()
  128. if message:
  129. try:
  130. self.process_single_message(message)
  131. self.receive_queue.ack(message)
  132. except Exception as e:
  133. logger.error("Error processing message: {}".format(e))
  134. traceback.print_exc()
  135. time.sleep(1)
  136. logger.info("Message processing thread exit")
  137. def start(self, blocking=False):
  138. self.running = True
  139. self.process_thread = threading.Thread(target=service.process_messages)
  140. self.process_thread.start()
  141. self.setup_scheduler()
  142. # 只有企微场景需要主动发起
  143. if not self.config['debug_flags'].get('disable_active_conversation', False):
  144. schedule_param = self.config['agent_behavior'].get('active_conversation_schedule_param', None)
  145. self.setup_initiative_conversations(schedule_param)
  146. signal.signal(signal.SIGINT, self._handle_sigint)
  147. if blocking:
  148. self.process_thread.join()
  149. def shutdown(self, sync=True):
  150. if not self.running:
  151. raise Exception("Service is not running")
  152. self.running = False
  153. self.scheduler.shutdown()
  154. if sync:
  155. self.process_thread.join()
  156. self.receive_queue.shutdown()
  157. self.send_queue.shutdown()
  158. if self.msg_scheduler_thread:
  159. self.msg_scheduler_thread.join()
  160. self.scheduler_queue.shutdown()
  161. def _handle_sigint(self, signum, frame):
  162. self._sigint_cnt += 1
  163. if self._sigint_cnt == 1:
  164. logger.warning("Try to shutdown gracefully...")
  165. self.shutdown(sync=True)
  166. else:
  167. logger.warning("Forcing exit")
  168. sys.exit(0)
  169. def _update_user_profile(self, user_id, user_profile, recent_dialogue: List[Dict]):
  170. profile_to_update = self.user_profile_extractor.extract_profile_info(user_profile, recent_dialogue)
  171. if not profile_to_update:
  172. logger.debug("user_id: {}, no profile info extracted".format(user_id))
  173. return
  174. logger.warning("update user profile: {}".format(profile_to_update))
  175. if profile_to_update.get('interaction_frequency', None) == 'stopped':
  176. # 和企微日常push联动,减少对用户的干扰
  177. if self.user_relation_manager.stop_user_daily_push(user_id):
  178. logger.warning(f"user[{user_id}]: daily push set to be stopped")
  179. merged_profile = self.user_profile_extractor.merge_profile_info(user_profile, profile_to_update)
  180. self.user_manager.save_user_profile(user_id, merged_profile)
  181. return merged_profile
  182. def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
  183. logger.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
  184. message_ts = int((time.time() + delay_sec) * 1000)
  185. msg = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
  186. # 系统消息使用特定的msgId,无实际意义
  187. msg.msgId = -MessageType.AGGREGATION_TRIGGER.value
  188. if self.scheduler_mode == 'mq':
  189. self.scheduler_queue.produce(msg)
  190. else:
  191. self.scheduler.add_job(lambda: self.receive_queue.produce(msg),
  192. 'date',
  193. run_date=datetime.now() + timedelta(seconds=delay_sec))
  194. def process_single_message(self, message: Message):
  195. user_id = message.sender
  196. staff_id = message.receiver
  197. # 获取用户信息和Agent实例
  198. user_profile = self.user_manager.get_user_profile(user_id)
  199. agent = self._get_agent_instance(staff_id, user_id)
  200. # 更新对话状态
  201. logger.debug("process message: {}".format(message))
  202. need_response, message_text = agent.update_state(message)
  203. logger.debug("user: {}, next state: {}".format(user_id, agent.current_state))
  204. # 根据状态路由消息
  205. try:
  206. if agent.is_in_human_intervention():
  207. self._route_to_human_intervention(user_id, message)
  208. elif agent.current_state == DialogueState.MESSAGE_AGGREGATING:
  209. if message.type != MessageType.AGGREGATION_TRIGGER:
  210. # 产生一个触发器,但是不能由触发器递归产生
  211. logger.debug("user: {}, waiting next message for aggregation".format(user_id))
  212. self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
  213. elif need_response:
  214. # 先更新用户画像再处理回复
  215. self._update_user_profile(user_id, user_profile, agent.dialogue_history[-10:])
  216. resp = self._get_chat_response(user_id, agent, message_text)
  217. if resp:
  218. recent_dialogue = agent.dialogue_history[-10:]
  219. if len(recent_dialogue) < 2 or staff_id not in ('1688855931724582', '1688854492669990'):
  220. message_type = MessageType.TEXT
  221. else:
  222. message_type = self.response_type_detector.detect_type(
  223. recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
  224. self._send_response(staff_id, user_id, resp, message_type)
  225. else:
  226. logger.debug(f"staff[{staff_id}], user[{user_id}]: do not need response")
  227. # 当前消息处理成功,commit并持久化agent状态
  228. agent.persist_state()
  229. except Exception as e:
  230. agent.rollback_state()
  231. raise e
  232. def _send_response(self, staff_id, user_id, response, message_type: MessageType, skip_check=False):
  233. logger.warning(f"staff[{staff_id}] user[{user_id}]: response[{message_type}] {response}")
  234. current_ts = int(time.time() * 1000)
  235. user_tags = self.user_relation_manager.get_user_tags(user_id)
  236. white_list_tags = set(apollo_config.get_json_value("agent_response_whitelist_tags"))
  237. hit_white_list_tags = len(set(user_tags).intersection(white_list_tags)) > 0
  238. # FIXME(zhoutian)
  239. # 测试期间临时逻辑,只发送特定的账号或特定用户
  240. staff_white_lists = set(apollo_config.get_json_value("agent_response_whitelist_staffs"))
  241. if not (staff_id in staff_white_lists or hit_white_list_tags or skip_check):
  242. logger.warning(f"staff[{staff_id}] user[{user_id}]: skip reply")
  243. return
  244. self.send_rate_limiter.wait_for_sending(staff_id, response)
  245. self.send_queue.produce(
  246. Message.build(message_type, MessageChannel.CORP_WECHAT,
  247. staff_id, user_id, response, current_ts)
  248. )
  249. def _route_to_human_intervention(self, user_id: str, origin_message: Message):
  250. """路由到人工干预"""
  251. self.human_queue.produce(Message.build(
  252. MessageType.TEXT,
  253. origin_message.channel,
  254. origin_message.sender,
  255. origin_message.receiver,
  256. "用户对话需人工介入,用户名:{}".format(user_id),
  257. int(time.time() * 1000)
  258. ))
  259. def _check_initiative_conversations(self):
  260. logger.info("start to check initiative conversations")
  261. if not DialogueManager.is_time_suitable_for_active_conversation():
  262. logger.info("time is not suitable for active conversation")
  263. return
  264. white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags'))
  265. first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
  266. # 合并白名单,减少配置成本
  267. white_list_tags.update(first_initiate_tags)
  268. voice_tags = set(apollo_config.get_json_value('agent_initiate_by_voice_tags'))
  269. """定时检查主动发起对话"""
  270. for staff_user in self.user_relation_manager.list_staff_users():
  271. staff_id = staff_user['staff_id']
  272. user_id = staff_user['user_id']
  273. agent = self._get_agent_instance(staff_id, user_id)
  274. should_initiate = agent.should_initiate_conversation()
  275. user_tags = self.user_relation_manager.get_user_tags(user_id)
  276. if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
  277. should_initiate = False
  278. if should_initiate:
  279. logger.warning(f"user[{user_id}], tags{user_tags}: initiate conversation")
  280. # FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突 需要并入事件驱动框架
  281. agent.do_state_change(DialogueState.GREETING)
  282. try:
  283. if agent.previous_state == DialogueState.INITIALIZED or first_initiate_tags.intersection(user_tags):
  284. # 完全无交互历史的用户才使用此策略,但新用户接入即会产生“我已添加了你”的消息将Agent初始化
  285. # 因此存量用户无法使用该状态做实验
  286. # TODO:增加基于对话历史的判断、策略去重;如果对话间隔过长需要使用长期记忆检索;在无长期记忆时,可采用用户添加时间来判断
  287. resp = self._generate_active_greeting_message(agent, user_tags)
  288. else:
  289. resp = self._get_chat_response(user_id, agent, None)
  290. if resp:
  291. if set(user_tags).intersection(voice_tags):
  292. message_type = MessageType.VOICE
  293. else:
  294. message_type = MessageType.TEXT
  295. self._send_response(staff_id, user_id, resp, message_type, skip_check=True)
  296. agent.persist_state()
  297. except Exception as e:
  298. # FIXME:虽然需要主动唤起的用户同时发来消息的概率很低,但仍可能会有并发冲突
  299. agent.rollback_state()
  300. logger.error("Error in active greeting: {}".format(e))
  301. else:
  302. logger.debug(f"user[{user_id}], do not initiate conversation")
  303. def _generate_active_greeting_message(self, agent: DialogueManager, user_tags: List[str]=None):
  304. chat_config = agent.build_active_greeting_config(user_tags)
  305. chat_response = self._call_chat_api(chat_config, ChatServiceType.OPENAI_COMPATIBLE)
  306. chat_response = self.sanitize_response(chat_response)
  307. if response := agent.generate_response(chat_response):
  308. return response
  309. else:
  310. logger.warning(f"staff[{agent.staff_id}] user[{agent.user_id}]: no response generated")
  311. return None
  312. def _get_chat_response(self, user_id: str, agent: DialogueManager,
  313. user_message: Optional[str]):
  314. """处理LLM响应"""
  315. chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
  316. config_for_logging = chat_config.copy()
  317. config_for_logging['messages'] = config_for_logging['messages'][-20:]
  318. logger.debug(config_for_logging)
  319. chat_response = self._call_chat_api(chat_config, self.chat_service_type)
  320. chat_response = self.sanitize_response(chat_response)
  321. if response := agent.generate_response(chat_response):
  322. return response
  323. else:
  324. logger.warning(f"staff[{agent.staff_id}] user[{user_id}]: no response generated")
  325. return None
  326. def _call_chat_api(self, chat_config: Dict, chat_service_type: ChatServiceType) -> str:
  327. if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
  328. return 'LLM模拟回复 {}'.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
  329. if chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
  330. # 指定了LLM模型则优先使用指定模型
  331. if chat_config.get('model_name', None):
  332. llm_client = chat_service.OpenAICompatible.create_client(chat_config['model_name'])
  333. chat_completion = llm_client.chat.completions.create(
  334. messages=chat_config['messages'],
  335. model=chat_config['model_name'],
  336. )
  337. elif chat_config.get('use_multimodal_model', False):
  338. chat_completion = self.multimodal_model_client.chat.completions.create(
  339. messages=chat_config['messages'],
  340. model=self.multimodal_model_name,
  341. )
  342. else:
  343. chat_completion = self.text_model_client.chat.completions.create(
  344. messages=chat_config['messages'],
  345. model=self.text_model_name,
  346. )
  347. response = chat_completion.choices[0].message.content
  348. elif chat_service_type == ChatServiceType.COZE_CHAT:
  349. bot_user_id = 'qywx_{}'.format(chat_config['user_id'])
  350. response = self.coze_client.create(
  351. chat_config['bot_id'], bot_user_id, chat_config['messages'],
  352. chat_config['custom_variables']
  353. )
  354. else:
  355. raise Exception('Unsupported chat service type: {}'.format(self.chat_service_type))
  356. return response
  357. @staticmethod
  358. def sanitize_response(response: str):
  359. pattern = r'\[?\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\]?'
  360. response = re.sub(pattern, '', response)
  361. response = response.strip()
  362. return response
  363. if __name__ == "__main__":
  364. config = configs.get()
  365. logging_service.setup_root_logger()
  366. logger.warning("current env: {}".format(configs.get_env()))
  367. scheduler_logger = logging.getLogger('apscheduler')
  368. scheduler_logger.setLevel(logging.WARNING)
  369. use_aliyun_mq = config['debug_flags']['use_aliyun_mq']
  370. # 初始化不同队列的后端
  371. if use_aliyun_mq:
  372. receive_queue = AliyunRocketMQQueueBackend(
  373. config['mq']['endpoints'],
  374. config['mq']['instance_id'],
  375. config['mq']['receive_topic'],
  376. has_consumer=True, has_producer=True,
  377. group_id=config['mq']['receive_group'],
  378. topic_type='FIFO'
  379. )
  380. send_queue = AliyunRocketMQQueueBackend(
  381. config['mq']['endpoints'],
  382. config['mq']['instance_id'],
  383. config['mq']['send_topic'],
  384. has_consumer=False, has_producer=True,
  385. topic_type='FIFO'
  386. )
  387. else:
  388. receive_queue = MemoryQueueBackend()
  389. send_queue = MemoryQueueBackend()
  390. human_queue = MemoryQueueBackend()
  391. # 初始化用户管理服务
  392. # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
  393. user_db_config = config['storage']['user']
  394. staff_db_config = config['storage']['staff']
  395. wecom_db_config = config['storage']['user_relation']
  396. if config['debug_flags'].get('use_local_user_storage', False):
  397. user_manager = LocalUserManager()
  398. user_relation_manager = LocalUserRelationManager()
  399. else:
  400. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
  401. user_relation_manager = MySQLUserRelationManager(
  402. user_db_config['mysql'], wecom_db_config['mysql'],
  403. config['storage']['staff']['table'],
  404. user_db_config['table'],
  405. wecom_db_config['table']['staff'],
  406. wecom_db_config['table']['relation'],
  407. wecom_db_config['table']['user']
  408. )
  409. # 创建Agent服务
  410. service = AgentService(
  411. receive_backend=receive_queue,
  412. send_backend=send_queue,
  413. human_backend=human_queue,
  414. user_manager=user_manager,
  415. user_relation_manager=user_relation_manager,
  416. chat_service_type=ChatServiceType.COZE_CHAT
  417. )
  418. if not config['debug_flags'].get('console_input', False):
  419. service.start(blocking=True)
  420. sys.exit(0)
  421. else:
  422. service.start()
  423. message_id = 0
  424. while service.running:
  425. print("Input next message: ")
  426. text = sys.stdin.readline().strip()
  427. if not text:
  428. continue
  429. message_id += 1
  430. sender = '7881301903997433'
  431. receiver = '1688855931724582'
  432. if text in (MessageType.AGGREGATION_TRIGGER.name,
  433. MessageType.HUMAN_INTERVENTION_END.name):
  434. message = Message.build(
  435. MessageType.__members__.get(text),
  436. MessageChannel.CORP_WECHAT,
  437. sender, receiver, None, int(time.time() * 1000))
  438. else:
  439. message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
  440. sender,receiver, text, int(time.time() * 1000)
  441. )
  442. message.msgId = message_id
  443. receive_queue.produce(message)
  444. time.sleep(0.1)