agent_service.py 8.0 KB


#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
import logging
import os
import sys
import threading
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

import apscheduler.triggers.cron
from apscheduler.schedulers.background import BackgroundScheduler
from openai import OpenAI

import global_flags
from dialogue_manager import DialogueManager, DialogueState
from user_manager import UserManager, LocalUserManager
from message_queue_backend import MessageQueueBackend, MemoryQueueBackend
from user_profile_extractor import UserProfileExtractor
from message import MessageType
from logging_service import ColoredFormatter
  19. class AgentService:
  20. def __init__(
  21. self,
  22. receive_backend: MessageQueueBackend,
  23. send_backend: MessageQueueBackend,
  24. human_backend: MessageQueueBackend,
  25. user_manager: UserManager
  26. ):
  27. self.receive_queue = receive_backend
  28. self.send_queue = send_backend
  29. self.human_queue = human_backend
  30. # 核心服务模块
  31. self.user_manager = user_manager
  32. self.user_profile_extractor = UserProfileExtractor()
  33. self.agent_registry: Dict[str, DialogueManager] = {}
  34. self.llm_client = OpenAI(
  35. api_key='5e275c38-44fd-415f-abcf-4b59f6377f72',
  36. base_url="https://ark.cn-beijing.volces.com/api/v3"
  37. )
  38. self.model_name = "ep-20250213194558-rrmr2" # DeepSeek on Volces
  39. # 定时任务调度器
  40. self.scheduler = BackgroundScheduler()
  41. self.scheduler.start()
  42. self._setup_initiative_conversations()
  43. def _setup_initiative_conversations(self):
  44. self.scheduler.add_job(
  45. self._check_initiative_conversations,
  46. apscheduler.triggers.cron.CronTrigger(second='5,35')
  47. )
  48. def _get_agent_instance(self, user_id: str) -> DialogueManager:
  49. """获取用户Agent实例"""
  50. if user_id not in self.agent_registry:
  51. self.agent_registry[user_id] = DialogueManager(
  52. user_id, self.user_manager)
  53. return self.agent_registry[user_id]
  54. def process_messages(self):
  55. """持续处理接收队列消息"""
  56. while True:
  57. message = self.receive_queue.consume()
  58. if message:
  59. self._process_single_message(message)
  60. time.sleep(1) # 避免CPU空转
  61. def _update_user_profile(self, user_id, user_profile, message: str):
  62. profile_to_update = self.user_profile_extractor.extract_profile_info(user_profile, message)
  63. if not profile_to_update:
  64. logging.debug("user_id: {}, no profile info extracted".format(user_id))
  65. return
  66. logging.warning("update user profile: {}".format(profile_to_update))
  67. merged_profile = self.user_profile_extractor.merge_profile_info(user_profile, profile_to_update)
  68. self.user_manager.save_user_profile(user_id, merged_profile)
  69. return merged_profile
  70. def _schedule_aggregation_trigger(self, user_id: str, delay_sec: int):
  71. logging.debug("user: {}, schedule trigger message after {} seconds".format(user_id, delay_sec))
  72. message = {
  73. 'user_id': user_id,
  74. 'type': MessageType.AGGREGATION_TRIGGER,
  75. 'text': None,
  76. 'timestamp': int(time.time() * 1000) + delay_sec * 1000
  77. }
  78. self.scheduler.add_job(lambda: self.receive_queue.produce(message),
  79. 'date',
  80. run_date=datetime.now() + timedelta(seconds=delay_sec))
  81. def _process_single_message(self, message: Dict):
  82. user_id = message['user_id']
  83. message_text = message.get('text', None)
  84. # 获取用户信息和Agent实例
  85. user_profile = self.user_manager.get_user_profile(user_id)
  86. agent = self._get_agent_instance(user_id)
  87. # 更新对话状态
  88. logging.debug("process message: {}".format(message))
  89. dialogue_state, message_text = agent.update_state(message)
  90. logging.debug("user: {}, next state: {}".format(user_id, dialogue_state))
  91. # 根据状态路由消息
  92. if agent.is_in_human_intervention():
  93. self._route_to_human_intervention(user_id, message_text, dialogue_state)
  94. elif dialogue_state == DialogueState.MESSAGE_AGGREGATING:
  95. if message['type'] != MessageType.AGGREGATION_TRIGGER:
  96. # 产生一个触发器,但是不能由触发器递归产生
  97. logging.debug("user: {}, waiting next message for aggregation".format(user_id))
  98. self._schedule_aggregation_trigger(user_id, agent.message_aggregation_sec)
  99. return
  100. else:
  101. # 先更新用户画像再处理回复
  102. self._update_user_profile(user_id, user_profile, message_text)
  103. self._process_llm_response(user_id, agent, message_text)
  104. def _route_to_human_intervention(self, user_id: str, user_message: str, state: DialogueState):
  105. """路由到人工干预"""
  106. self.human_queue.produce({
  107. 'user_id': user_id,
  108. 'state': state,
  109. 'timestamp': datetime.now().isoformat()
  110. })
  111. def _process_llm_response(self, user_id: str, agent: DialogueManager,
  112. user_message: str):
  113. """处理LLM响应"""
  114. messages = agent.make_llm_messages(user_message)
  115. logging.debug(messages)
  116. llm_response = self._call_llm_api(messages)
  117. if response := agent.generate_response(llm_response):
  118. logging.warning("user: {}, response: {}".format(user_id, response))
  119. self.send_queue.produce({
  120. 'user_id': user_id,
  121. 'text': response,
  122. })
  123. def _check_initiative_conversations(self):
  124. """定时检查主动发起对话"""
  125. for user_id in self.user_manager.list_all_users():
  126. agent = self._get_agent_instance(user_id)
  127. should_initiate = agent.should_initiate_conversation()
  128. if should_initiate:
  129. logging.warning("user: {}, initiate conversation".format(user_id))
  130. self._process_llm_response(user_id, agent, None)
  131. else:
  132. logging.debug("user: {}, do not initiate conversation".format(user_id))
  133. def _call_llm_api(self, messages: List[Dict]) -> str:
  134. if global_flags.DISABLE_LLM_API_CALL:
  135. return 'LLM模拟回复'
  136. chat_completion = self.llm_client.chat.completions.create(
  137. messages=messages,
  138. model=self.model_name,
  139. )
  140. response = chat_completion.choices[0].message.content
  141. return response
  142. if __name__ == "__main__":
  143. logging.getLogger().setLevel(logging.DEBUG)
  144. console_handler = logging.StreamHandler()
  145. console_handler.setLevel(logging.WARNING)
  146. formatter = ColoredFormatter(
  147. '%(asctime)s - %(funcName)s[%(lineno)d] - %(levelname)s - %(message)s'
  148. )
  149. console_handler.setFormatter(formatter)
  150. root_logger = logging.getLogger()
  151. root_logger.handlers.clear()
  152. root_logger.addHandler(console_handler)
  153. scheduler_logger = logging.getLogger('apscheduler')
  154. scheduler_logger.setLevel(logging.WARNING)
  155. # 初始化不同队列的后端
  156. receive_queue = MemoryQueueBackend()
  157. send_queue = MemoryQueueBackend()
  158. human_queue = MemoryQueueBackend()
  159. # 初始化用户管理服务
  160. user_manager = LocalUserManager()
  161. global_flags.DISABLE_LLM_API_CALL = False
  162. # 创建Agent服务
  163. service = AgentService(
  164. receive_backend=receive_queue,
  165. send_backend=send_queue,
  166. human_backend=human_queue,
  167. user_manager=user_manager
  168. )
  169. process_thread = threading.Thread(target=service.process_messages)
  170. process_thread.start()
  171. while True:
  172. print("Input next message: ")
  173. message = sys.stdin.readline().strip()
  174. message_dict = {
  175. "user_id": "user_id_1",
  176. "type": MessageType.TEXT,
  177. "text": message,
  178. "timestamp": int(time.time() * 1000)
  179. }
  180. if message:
  181. receive_queue.produce(message_dict)
  182. time.sleep(0.1)