push_service.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import json
  2. import time
  3. import uuid
  4. from enum import Enum
  5. from concurrent.futures import ThreadPoolExecutor
  6. from threading import Thread
  7. from typing import Optional, Dict
  8. import rocketmq
  9. from rocketmq import ClientConfiguration, Credentials, SimpleConsumer, FilterExpression
  10. from pqai_agent import configs
  11. from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessagePushAgent
  12. from pqai_agent.configs import apollo_config
  13. from pqai_agent.logging_service import logger
  14. from pqai_agent.message import MessageType
  15. class TaskType(Enum):
  16. GENERATE = "generate"
  17. SEND = "send"
  18. def generate_task_rmq_message(topic: str, staff_id: str, user_id: str, task_type: TaskType, content: Optional[str] = None) -> rocketmq.Message:
  19. msg = rocketmq.Message()
  20. msg.topic = topic
  21. msg.body = json.dumps({
  22. 'staff_id': staff_id,
  23. 'user_id': user_id,
  24. 'task_type': task_type.value,
  25. # FIXME: 需要支持多模态消息
  26. 'content': content or '',
  27. 'timestamp': int(time.time() * 1000),
  28. }, ensure_ascii=False).encode('utf-8')
  29. msg.tag = task_type.value
  30. return msg
  31. class PushScanThread:
  32. # PushScanThread实际可以是AgentService的一个函数,从AgentService中独立的主要考虑因素为Push后续可能有拆分和扩展
  33. def __init__(self, staff_id: str, agent_service: 'AgentService', mq_topic: str, mq_producer: rocketmq.Producer):
  34. self.staff_id = staff_id
  35. # 需要大量使用AgentService内部的成员
  36. self.service = agent_service
  37. self.rmq_topic = mq_topic
  38. self.rmq_producer = mq_producer
  39. def run(self):
  40. white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags'))
  41. first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
  42. # 合并白名单,减少配置成本
  43. white_list_tags.update(first_initiate_tags)
  44. for staff_user in self.service.user_relation_manager.list_staff_users(staff_id=self.staff_id):
  45. staff_id = staff_user['staff_id']
  46. user_id = staff_user['user_id']
  47. agent = self.service.get_agent_instance(staff_id, user_id)
  48. should_initiate = agent.should_initiate_conversation()
  49. user_tags = self.service.user_relation_manager.get_user_tags(user_id)
  50. if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
  51. should_initiate = False
  52. if should_initiate:
  53. logger.info(f"user[{user_id}], tags{user_tags}: generate a generation task for conversation initiation")
  54. rmq_msg = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.GENERATE)
  55. self.rmq_producer.send(rmq_msg)
  56. else:
  57. logger.debug(f"user[{user_id}], do not initiate conversation")
  58. class PushTaskWorkerPool:
  59. def __init__(self, agent_service: 'AgentService', mq_topic: str,
  60. mq_consumer: rocketmq.SimpleConsumer, mq_producer: rocketmq.Producer):
  61. self.agent_service = agent_service
  62. self.generate_executor = ThreadPoolExecutor(max_workers=5)
  63. self.send_executors = {}
  64. self.rmq_topic = mq_topic
  65. self.consumer = mq_consumer
  66. self.producer = mq_producer
  67. self.loop_thread = None
  68. self.is_generator_running = True
  69. self.generate_send_done = False # set by wait_to_finish
  70. self.no_more_generate_task = False # set by self
  71. def start(self):
  72. self.loop_thread = Thread(target=self.process_push_tasks)
  73. self.loop_thread.start()
  74. def process_push_tasks(self):
  75. while True:
  76. msgs = self.consumer.receive(1, 300)
  77. if not msgs:
  78. # 没有生成任务在执行且没有消息,才可退出
  79. if self.generate_send_done:
  80. if not self.no_more_generate_task:
  81. logger.debug("no message received, there should be no more generate task")
  82. self.no_more_generate_task = True
  83. continue
  84. else:
  85. if self.is_generator_running:
  86. logger.debug("Waiting for generator threads to finish")
  87. continue
  88. else:
  89. break
  90. else:
  91. continue
  92. msg = msgs[0]
  93. task = json.loads(msg.body.decode('utf-8'))
  94. logger.debug(f"recv message: {task}")
  95. if task['task_type'] == TaskType.GENERATE.value:
  96. self.generate_executor.submit(self.handle_generate_task, task, msg)
  97. elif task['task_type'] == TaskType.SEND.value:
  98. staff_id = task['staff_id']
  99. if staff_id not in self.send_executors:
  100. self.send_executors[staff_id] = ThreadPoolExecutor(max_workers=1)
  101. self.send_executors[staff_id].submit(self.handle_send_task, task, msg)
  102. else:
  103. logger.error(f"Unknown task type: {task['task_type']}")
  104. self.consumer.ack(msg)
  105. logger.info("PushGenerateWorkerPool stopped")
  106. def wait_to_finish(self):
  107. self.generate_send_done = True
  108. while not self.no_more_generate_task:
  109. #FIXME(zhoutian): condition variable should be used to replace time sleep
  110. time.sleep(1)
  111. self.generate_executor.shutdown(wait=True)
  112. self.is_generator_running = False
  113. self.loop_thread.join()
  114. def handle_send_task(self, task: Dict, msg: rocketmq.Message):
  115. try:
  116. staff_id = task['staff_id']
  117. user_id = task['user_id']
  118. agent = self.agent_service.get_agent_instance(staff_id, user_id)
  119. # 二次校验是否需要发送
  120. if not agent.should_initiate_conversation():
  121. logger.debug(f"user[{user_id}], do not initiate conversation")
  122. self.consumer.ack(msg)
  123. return
  124. content = task['content']
  125. recent_dialogue = agent.dialogue_history[-10:]
  126. agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist"))
  127. if len(recent_dialogue) < 2 or staff_id not in agent_voice_whitelist:
  128. message_type = MessageType.TEXT
  129. else:
  130. message_type = self.agent_service.response_type_detector.detect_type(
  131. recent_dialogue[:-1], recent_dialogue[-1], enable_random=True)
  132. response = agent.generate_response(content)
  133. if response:
  134. self.agent_service.send_response(staff_id, user_id, response, message_type, skip_check=True)
  135. else:
  136. logger.debug(f"agent[{staff_id}] generate empty response")
  137. self.consumer.ack(msg)
  138. except Exception as e:
  139. logger.error(f"Error processing message sending: {e}")
  140. self.consumer.ack(msg)
  141. def handle_generate_task(self, task: Dict, msg: rocketmq.Message):
  142. try:
  143. staff_id = task['staff_id']
  144. user_id = task['user_id']
  145. main_agent = self.agent_service.get_agent_instance(staff_id, user_id)
  146. push_agent = DummyMessagePushAgent()
  147. message_to_user = push_agent.generate_message(
  148. context=main_agent.get_prompt_context(None),
  149. dialogue_history=main_agent.dialogue_history
  150. )
  151. rmq_message = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.SEND, message_to_user)
  152. logger.debug(f"send message: {rmq_message.body.decode('utf-8')}")
  153. self.producer.send(rmq_message)
  154. self.consumer.ack(msg)
  155. except Exception as e:
  156. logger.error(f"Error processing message generation: {e}")
  157. # FIXME: 是否需要ACK
  158. self.consumer.ack(msg)