agent_service.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import sys
  5. import time
  6. from typing import Dict, List, Tuple, Any, Optional
  7. import logging
  8. from datetime import datetime, timedelta
  9. import apscheduler.triggers.cron
  10. from apscheduler.schedulers.background import BackgroundScheduler
  11. import chat_service
  12. import configs
  13. import logging_service
  14. from chat_service import CozeChat, ChatServiceType
  15. from dialogue_manager import DialogueManager, DialogueState
  16. from user_manager import UserManager, LocalUserManager, MySQLUserManager, MySQLUserRelationManager, UserRelationManager
  17. from openai import OpenAI
  18. from message_queue_backend import MessageQueueBackend, MemoryQueueBackend, AliyunRocketMQQueueBackend
  19. from user_profile_extractor import UserProfileExtractor
  20. import threading
  21. from message import MessageType, Message, MessageChannel
  22. from logging_service import ColoredFormatter
  23. class AgentService:
  24. def __init__(
  25. self,
  26. receive_backend: MessageQueueBackend,
  27. send_backend: MessageQueueBackend,
  28. human_backend: MessageQueueBackend,
  29. user_manager: UserManager,
  30. user_relation_manager: UserRelationManager,
  31. chat_service_type: ChatServiceType = ChatServiceType.OPENAI_COMPATIBLE
  32. ):
  33. self.receive_queue = receive_backend
  34. self.send_queue = send_backend
  35. self.human_queue = human_backend
  36. # 核心服务模块
  37. self.user_manager = user_manager
  38. self.user_relation_manager = user_relation_manager
  39. self.user_profile_extractor = UserProfileExtractor()
  40. self.agent_registry: Dict[str, DialogueManager] = {}
  41. self.llm_client = OpenAI(
  42. api_key=chat_service.VOLCENGINE_API_TOKEN,
  43. base_url=chat_service.VOLCENGINE_BASE_URL
  44. )
  45. # DeepSeek on Volces
  46. self.model_name = chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3
  47. self.coze_client = CozeChat(
  48. token=chat_service.COZE_API_TOKEN,
  49. base_url=chat_service.COZE_CN_BASE_URL
  50. )
  51. self.chat_service_type = chat_service_type
  52. # 定时任务调度器
  53. self.scheduler = BackgroundScheduler()
  54. self.scheduler.start()
  55. def setup_initiative_conversations(self, schedule_params: Optional[Dict] = None):
  56. if not schedule_params:
  57. schedule_params = {'hour': '8,16,20'}
  58. self.scheduler.add_job(
  59. self._check_initiative_conversations,
  60. apscheduler.triggers.cron.CronTrigger(**schedule_params)
  61. )
  62. def _get_agent_instance(self, staff_id: str, user_id: str) -> DialogueManager:
  63. """获取Agent实例"""
  64. agent_key = 'agent_{}_{}'.format(staff_id, user_id)
  65. if agent_key not in self.agent_registry:
  66. self.agent_registry[agent_key] = DialogueManager(
  67. staff_id, user_id, self.user_manager)
  68. return self.agent_registry[agent_key]
  69. def process_messages(self):
  70. """持续处理接收队列消息"""
  71. while True:
  72. message = self.receive_queue.consume()
  73. if message:
  74. try:
  75. self.process_single_message(message)
  76. except Exception as e:
  77. logging.error("Error processing message: {}".format(e))
  78. # 无论处理是否有异常,都ACK消息
  79. self.receive_queue.ack(message)
  80. time.sleep(1)
  81. def _update_user_profile(self, user_id, user_profile, message: str):
  82. profile_to_update = self.user_profile_extractor.extract_profile_info(user_profile, message)
  83. if not profile_to_update:
  84. logging.debug("user_id: {}, no profile info extracted".format(user_id))
  85. return
  86. logging.warning("update user profile: {}".format(profile_to_update))
  87. merged_profile = self.user_profile_extractor.merge_profile_info(user_profile, profile_to_update)
  88. self.user_manager.save_user_profile(user_id, merged_profile)
  89. return merged_profile
  90. def _schedule_aggregation_trigger(self, staff_id: str, user_id: str, delay_sec: int):
  91. logging.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
  92. message_ts = int((time.time() + delay_sec) * 1000)
  93. message = Message.build(MessageType.AGGREGATION_TRIGGER, MessageChannel.SYSTEM, user_id, staff_id, None, message_ts)
  94. # 系统消息使用特定的msgId,无实际意义
  95. message.msgId = -MessageType.AGGREGATION_TRIGGER.value
  96. self.scheduler.add_job(lambda: self.receive_queue.produce(message),
  97. 'date',
  98. run_date=datetime.now() + timedelta(seconds=delay_sec))
  99. def process_single_message(self, message: Message):
  100. user_id = message.sender
  101. staff_id = message.receiver
  102. # 获取用户信息和Agent实例
  103. user_profile = self.user_manager.get_user_profile(user_id)
  104. agent = self._get_agent_instance(staff_id, user_id)
  105. # 更新对话状态
  106. logging.debug("process message: {}".format(message))
  107. dialogue_state, message_text = agent.update_state(message)
  108. logging.debug("user: {}, next state: {}".format(user_id, dialogue_state))
  109. # 根据状态路由消息
  110. if agent.is_in_human_intervention():
  111. self._route_to_human_intervention(user_id, message)
  112. elif dialogue_state == DialogueState.MESSAGE_AGGREGATING:
  113. if message.type != MessageType.AGGREGATION_TRIGGER:
  114. # 产生一个触发器,但是不能由触发器递归产生
  115. logging.debug("user: {}, waiting next message for aggregation".format(user_id))
  116. self._schedule_aggregation_trigger(staff_id, user_id, agent.message_aggregation_sec)
  117. return
  118. else:
  119. # 先更新用户画像再处理回复
  120. self._update_user_profile(user_id, user_profile, message_text)
  121. self._get_chat_response(user_id, agent, message_text)
  122. def _route_to_human_intervention(self, user_id: str, origin_message: Message):
  123. """路由到人工干预"""
  124. self.human_queue.produce(Message.build(
  125. MessageType.TEXT,
  126. origin_message.channel,
  127. origin_message.sender,
  128. origin_message.receiver,
  129. "用户对话需人工介入,用户名:{}".format(user_id),
  130. int(time.time() * 1000)
  131. ))
  132. def _check_initiative_conversations(self):
  133. """定时检查主动发起对话"""
  134. for staff_user in self.user_relation_manager.list_staff_users():
  135. staff_id = staff_user['staff_id']
  136. user_id = staff_user['user_id']
  137. agent = self._get_agent_instance(staff_id, user_id)
  138. should_initiate = agent.should_initiate_conversation()
  139. if should_initiate:
  140. logging.warning("user: {}, initiate conversation".format(user_id))
  141. self._get_chat_response(user_id, agent, None)
  142. else:
  143. logging.debug("user: {}, do not initiate conversation".format(user_id))
  144. def _get_chat_response(self, user_id: str, agent: DialogueManager,
  145. user_message: Optional[str]):
  146. """处理LLM响应"""
  147. chat_config = agent.build_chat_configuration(user_message, self.chat_service_type)
  148. logging.debug(chat_config)
  149. chat_response = self._call_chat_api(chat_config)
  150. if response := agent.generate_response(chat_response):
  151. logging.warning("user: {}, response: {}".format(user_id, response))
  152. current_ts = int(time.time() * 1000)
  153. self.send_queue.produce(
  154. Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
  155. agent.staff_id, user_id, response, current_ts)
  156. )
  157. def _call_chat_api(self, chat_config: Dict) -> str:
  158. if configs.get().get('debug_flags', {}).get('disable_llm_api_call', False):
  159. return 'LLM模拟回复 {}'.format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
  160. if self.chat_service_type == ChatServiceType.OPENAI_COMPATIBLE:
  161. chat_completion = self.llm_client.chat.completions.create(
  162. messages=chat_config['messages'],
  163. model=self.model_name,
  164. )
  165. response = chat_completion.choices[0].message.content
  166. elif self.chat_service_type == ChatServiceType.COZE_CHAT:
  167. bot_user_id = 'dev_user'
  168. response = self.coze_client.create(
  169. chat_config['bot_id'], bot_user_id, chat_config['messages'],
  170. chat_config['custom_variables']
  171. )
  172. else:
  173. raise Exception('Unsupported chat service type: {}'.format(self.chat_service_type))
  174. return response
  175. if __name__ == "__main__":
  176. config = configs.get()
  177. logging_service.setup_root_logger()
  178. scheduler_logger = logging.getLogger('apscheduler')
  179. scheduler_logger.setLevel(logging.WARNING)
  180. use_aliyun_mq = config['use_aliyun_mq']
  181. # 初始化不同队列的后端
  182. if use_aliyun_mq:
  183. receive_queue = AliyunRocketMQQueueBackend(
  184. config['mq']['endpoints'],
  185. config['mq']['instance_id'],
  186. config['mq']['receive_topic'],
  187. has_consumer=True, has_producer=True,
  188. group_id=config['mq']['receive_group']
  189. )
  190. send_queue = AliyunRocketMQQueueBackend(
  191. config['mq']['endpoints'],
  192. config['mq']['instance_id'],
  193. config['mq']['send_topic'],
  194. has_consumer=False, has_producer=True
  195. )
  196. else:
  197. receive_queue = MemoryQueueBackend()
  198. send_queue = MemoryQueueBackend()
  199. human_queue = MemoryQueueBackend()
  200. # 初始化用户管理服务
  201. # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
  202. user_db_config = config['storage']['user']
  203. if config['debug_flags'].get('use_local_user_storage', False):
  204. user_manager = LocalUserManager()
  205. else:
  206. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'])
  207. wecom_db_config = config['storage']['user_relation']
  208. user_relation_manager = MySQLUserRelationManager(
  209. user_db_config['mysql'], wecom_db_config['mysql'],
  210. config['storage']['staff']['table'],
  211. user_db_config['table'],
  212. wecom_db_config['table']['staff'],
  213. wecom_db_config['table']['relation'],
  214. wecom_db_config['table']['user']
  215. )
  216. # 创建Agent服务
  217. service = AgentService(
  218. receive_backend=receive_queue,
  219. send_backend=send_queue,
  220. human_backend=human_queue,
  221. user_manager=user_manager,
  222. user_relation_manager=user_relation_manager,
  223. chat_service_type=ChatServiceType.COZE_CHAT
  224. )
  225. # 只有企微场景需要主动发起
  226. service.setup_initiative_conversations({'second': '5,35'})
  227. process_thread = threading.Thread(target=service.process_messages)
  228. process_thread.start()
  229. console_input = True
  230. if not console_input:
  231. process_thread.join()
  232. sys.exit(0)
  233. message_id = 0
  234. while True:
  235. print("Input next message: ")
  236. text = sys.stdin.readline().strip()
  237. if not text:
  238. continue
  239. message_id += 1
  240. message = Message.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
  241. '7881302581935903','1688854492669990', text, int(time.time() * 1000)
  242. )
  243. message.msgId = message_id
  244. receive_queue.produce(message)
  245. time.sleep(0.1)
  246. process_thread.join()