push_service.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. import json
  2. import time
  3. import traceback
  4. import uuid
  5. from datetime import datetime
  6. from enum import Enum
  7. from concurrent.futures import ThreadPoolExecutor
  8. from threading import Thread
  9. from typing import Optional, Dict, List
  10. import rocketmq
  11. from rocketmq import ClientConfiguration, Credentials, SimpleConsumer, FilterExpression
  12. from pqai_agent import configs
  13. from pqai_agent.agents.message_push_agent import MessagePushAgent, DummyMessagePushAgent
  14. from pqai_agent.configs import apollo_config
  15. from pqai_agent.data_models.agent_push_record import AgentPushRecord
  16. from pqai_agent.logging_service import logger
  17. from pqai_agent.mq_message import MessageType
  18. from pqai_agent.utils.agent_abtest_utils import get_agent_abtest_config
  class TaskType(Enum):
      """Kinds of push task exchanged over the RocketMQ topic.

      The member value is what gets serialized into a task message: it is
      written to the ``task_type`` field of the JSON body and reused as the
      message tag.
      """

      # Request to generate the content of a proactive push message.
      GENERATE = "generate"
      # Request to deliver already-generated content.
      SEND = "send"
  def generate_task_rmq_message(topic: str, staff_id: str, user_id: str, task_type: TaskType,
                                content: Optional[str] = None) -> rocketmq.Message:
      """Build the RocketMQ message that describes one push task.

      Args:
          topic: Topic the message is addressed to.
          staff_id: Staff identifier stored in the body.
          user_id: User identifier stored in the body.
          task_type: Task kind; its value is stored in the body and used as the tag.
          content: Optional payload string. ``None`` (or any falsy value) is
              stored as an empty string.

      Returns:
          A ``rocketmq.Message`` whose body is UTF-8 encoded JSON with the keys
          ``staff_id``, ``user_id``, ``task_type``, ``content`` and ``timestamp``
          (creation time in milliseconds).
      """
      payload = {
          'staff_id': staff_id,
          'user_id': user_id,
          'task_type': task_type.value,
          # NOTE: multimodal messages are supported by passing a JSON string as content.
          'content': content or '',
          'timestamp': int(time.time() * 1000),
      }
      msg = rocketmq.Message()
      msg.topic = topic
      # ensure_ascii=False keeps non-ASCII text readable in the encoded body.
      msg.body = json.dumps(payload, ensure_ascii=False).encode('utf-8')
      msg.tag = task_type.value
      return msg
  class PushScanThread:
      """Scan the users of one staff member and enqueue GENERATE tasks for them.

      NOTE: this could simply be a function of AgentService; it is kept separate
      mainly because push may later be split out and extended.
      """

      def __init__(self, staff_id: str, agent_service: 'AgentService', mq_topic: str, mq_producer: rocketmq.Producer):
          """Store the scan target and its collaborators.

          Args:
              staff_id: Staff member whose users are scanned.
              agent_service: Service providing the user relation manager and agent instances.
              mq_topic: Topic the generated tasks are sent to.
              mq_producer: Producer used to send the task messages.
          """
          self.staff_id = staff_id
          # Many internal members of AgentService are used by run().
          self.service = agent_service
          self.rmq_topic = mq_topic
          self.rmq_producer = mq_producer

      def run(self):
          """Send one GENERATE task for every user whose conversation should be initiated.

          A user qualifies when the agent reports that a conversation should be
          initiated and, outside the ``dev`` environment, the user carries at
          least one whitelisted tag.
          """
          white_list_tags = set(apollo_config.get_json_value('agent_initiate_whitelist_tags', []))
          first_initiate_tags = set(apollo_config.get_json_value('agent_first_initiate_whitelist_tags', []))
          # Merge both whitelists to reduce configuration cost.
          white_list_tags.update(first_initiate_tags)

          relation_manager = self.service.user_relation_manager
          for staff_user in relation_manager.list_staff_users(staff_id=self.staff_id):
              staff_id = staff_user['staff_id']
              user_id = staff_user['user_id']
              agent = self.service.get_agent_instance(staff_id, user_id)
              should_initiate = agent.should_initiate_conversation()
              user_tags = relation_manager.get_user_tags(user_id)
              # The whitelist is not enforced in the dev environment.
              if configs.get_env() != 'dev' and not white_list_tags.intersection(user_tags):
                  should_initiate = False

              if not should_initiate:
                  logger.debug(f"user[{user_id}], do not initiate conversation")
                  continue

              logger.info(f"user[{user_id}], tags{user_tags}: generate a generation task for conversation initiation")
              rmq_msg = generate_task_rmq_message(self.rmq_topic, staff_id, user_id, TaskType.GENERATE)
              self.rmq_producer.send(rmq_msg)
  class PushTaskWorkerPool:
      """Consume push tasks from RocketMQ and run them on thread pools.

      GENERATE tasks share one executor whose size comes from the
      ``push_task_workers`` system setting (default 5). SEND tasks use one
      single-worker executor per staff id, so the sends of one staff member
      are executed one at a time.
      """

      def __init__(self, agent_service: 'AgentService', mq_topic: str,
                   mq_consumer: 'rocketmq.SimpleConsumer', mq_producer: 'rocketmq.Producer'):
          """Create the executors and remember the messaging endpoints.

          Args:
              agent_service: Service providing agents, configuration managers and DB sessions.
              mq_topic: Topic that SEND tasks are published to.
              mq_consumer: Consumer the tasks are received from and acknowledged on.
              mq_producer: Producer used to publish SEND tasks.
          """
          self.agent_service = agent_service
          max_workers = configs.get()['system'].get('push_task_workers', 5)
          self.generate_executor = ThreadPoolExecutor(max_workers=max_workers)
          # staff_id -> single-worker executor, created on first SEND task of that staff.
          self.send_executors = {}
          self.rmq_topic = mq_topic
          self.consumer = mq_consumer
          self.producer = mq_producer
          self.loop_thread = None
          self.is_generator_running = True
          self.generate_send_done = False  # set by wait_to_finish
          self.no_more_generate_task = False  # set by self

      def start(self):
          """Start the consumer loop on a background thread."""
          self.loop_thread = Thread(target=self.process_push_tasks)
          self.loop_thread.start()

      def process_push_tasks(self):
          """Receive tasks and dispatch them until the shutdown handshake completes.

          The loop exits only on an empty receive that happens after
          ``wait_to_finish`` has been called, an earlier empty receive has been
          recorded in ``no_more_generate_task``, and the generate executor has
          been shut down (``is_generator_running`` is False).
          """
          # The RMQ consumer seems to have a bug: consuming right after creation may raise an NPE.
          time.sleep(1)
          while True:
              msgs = self.consumer.receive(1, 300)
              if not msgs:
                  # Exit only when there is no message and no generate task is still running.
                  if not self.generate_send_done:
                      continue
                  if not self.no_more_generate_task:
                      logger.debug("no message received, there should be no more generate task")
                      self.no_more_generate_task = True
                      continue
                  if self.is_generator_running:
                      logger.debug("Waiting for generator threads to finish")
                      continue
                  break

              msg = msgs[0]
              try:
                  self._dispatch_task(msg)
              except Exception as e:
                  # A malformed message or a rejected submit must not kill the loop
                  # thread: wait_to_finish depends on this loop to make progress.
                  # Ack on error, as the task handlers do, so the message is not left pending.
                  fmt_exc = traceback.format_exc()
                  logger.error(f"Error dispatching push task: {e}, {fmt_exc}")
                  self.consumer.ack(msg)
          logger.info("PushGenerateWorkerPool stopped")

      def _dispatch_task(self, msg):
          """Decode one received message and submit it to the executor of its task type."""
          task = json.loads(msg.body.decode('utf-8'))
          msg_time = datetime.fromtimestamp(task['timestamp'] / 1000).strftime("%Y-%m-%d %H:%M:%S")
          logger.debug(f"recv message:{msg_time} - {task}")
          task_type = task['task_type']
          if task_type == TaskType.GENERATE.value:
              self.generate_executor.submit(self.handle_generate_task, task, msg)
          elif task_type == TaskType.SEND.value:
              staff_id = task['staff_id']
              if staff_id not in self.send_executors:
                  self.send_executors[staff_id] = ThreadPoolExecutor(max_workers=1)
              self.send_executors[staff_id].submit(self.handle_send_task, task, msg)
          else:
              # The handlers ack their own messages; an unknown task has no handler, so ack here.
              logger.error(f"Unknown task type: {task['task_type']}")
              self.consumer.ack(msg)

      def wait_to_finish(self):
          """Block until the consumer loop has stopped and all submitted tasks are done.

          Must be called after ``start``. Signals that no further GENERATE tasks
          will be produced, waits for the loop to observe an empty receive,
          shuts down the generate executor, joins the loop thread and finally
          shuts down the per-staff send executors.
          """
          self.generate_send_done = True
          while not self.no_more_generate_task:
              # FIXME(zhoutian): condition variable should be used to replace time sleep
              time.sleep(1)
          self.generate_executor.shutdown(wait=True)
          self.is_generator_running = False
          self.loop_thread.join()
          # The loop thread was the only submitter, so the send executors can now
          # be drained and released instead of being left running.
          for send_executor in list(self.send_executors.values()):
              send_executor.shutdown(wait=True)

      def handle_send_task(self, task: Dict, msg: 'rocketmq.Message'):
          """Deliver the content of one SEND task to the user and acknowledge the message.

          The message is acknowledged on every path, including errors, which are
          logged and not re-raised.

          Args:
              task: Decoded task body; ``content`` is a JSON list of message items.
              msg: The received message, acknowledged once handling is over.
          """
          try:
              staff_id = task['staff_id']
              user_id = task['user_id']
              agent = self.agent_service.get_agent_instance(staff_id, user_id)
              # Check a second time whether sending is still needed.
              if not agent.should_initiate_conversation():
                  logger.debug(f"user[{user_id}], do not initiate conversation")
                  self.consumer.ack(msg)
                  return
              contents: List[Dict] = json.loads(task['content'])
              if not contents:
                  logger.debug(f"staff[{staff_id}], user[{user_id}]: empty content, do not send")
                  self.consumer.ack(msg)
                  return
              recent_dialogue = agent.dialogue_history[-10:]
              agent_voice_whitelist = set(apollo_config.get_json_value("agent_voice_whitelist", []))
              messages_to_send = []
              current_ts = int(time.time())
              for item in contents:
                  # Items carry milliseconds; current_ts itself is in seconds.
                  item["timestamp"] = current_ts * 1000
                  if item["type"] == "text":
                      if staff_id not in agent_voice_whitelist:
                          message_type = MessageType.TEXT
                      else:
                          # Whitelisted staff let the detector choose the response type.
                          message_type = self.agent_service.response_type_detector.detect_type(
                              recent_dialogue, item, enable_random=True)
                      response = agent.generate_response(item["content"])
                      if response:
                          messages_to_send.append({'type': message_type, 'content': response})
                  else:
                      message_type = MessageType.from_str(item["type"])
                      response = agent.generate_multimodal_response(item)
                      if response:
                          item["type"] = message_type
                          messages_to_send.append(item)
              # The push record is stored even when nothing ends up being sent.
              with self.agent_service.agent_db_session_maker() as session:
                  msg_list = [{"type": outgoing["type"].value, "content": outgoing["content"]}
                              for outgoing in messages_to_send]
                  record = AgentPushRecord(staff_id=staff_id, user_id=user_id,
                                           content=json.dumps(msg_list, ensure_ascii=False),
                                           timestamp=current_ts)
                  session.add(record)
                  session.commit()
              if messages_to_send:
                  for response in messages_to_send:
                      self.agent_service.send_multimodal_response(staff_id, user_id, response, skip_check=True)
                  agent.update_last_active_interaction_time(current_ts)
              else:
                  logger.debug(f"staff[{staff_id}], user[{user_id}]: generate empty response")
              self.consumer.ack(msg)
          except Exception as e:
              fmt_exc = traceback.format_exc()
              logger.error(f"Error processing message sending: {e}, {fmt_exc}")
              self.consumer.ack(msg)

      def handle_generate_task(self, task: Dict, msg: 'rocketmq.Message'):
          """Generate push content for one GENERATE task and publish a SEND task for it.

          The message is acknowledged on every path, including errors, which are
          logged and not re-raised.

          Args:
              task: Decoded task body with ``staff_id`` and ``user_id``.
              msg: The received message, acknowledged once handling is over.
          """
          try:
              staff_id = task['staff_id']
              user_id = task['user_id']
              main_agent = self.agent_service.get_agent_instance(staff_id, user_id)
              agent_config = get_agent_abtest_config('push', user_id,
                                                     self.agent_service.service_module_manager,
                                                     self.agent_service.agent_config_manager)
              if agent_config:
                  # An A/B test configuration supplies the model and prompts.
                  push_agent = MessagePushAgent(model=agent_config.execution_model,
                                                system_prompt=agent_config.system_prompt,
                                                tools=None)
                  query_prompt_template = agent_config.task_prompt
              else:
                  push_agent = MessagePushAgent()
                  query_prompt_template = None
              message_to_user = push_agent.generate_message(
                  context=main_agent.get_prompt_context(None),
                  dialogue_history=self.agent_service.history_dialogue_db.get_dialogue_history_backward(
                      staff_id, user_id, main_agent.last_interaction_time_ms, limit=100
                  ),
                  query_prompt_template=query_prompt_template
              )
              if message_to_user:
                  rmq_message = generate_task_rmq_message(
                      self.rmq_topic, staff_id, user_id, TaskType.SEND, json.dumps(message_to_user))
                  self.producer.send(rmq_message)
              else:
                  logger.info(f"staff[{staff_id}], user[{user_id}]: no push message generated")
              self.consumer.ack(msg)
          except Exception as e:
              fmt_exc = traceback.format_exc()
              logger.error(f"Error processing message generation: {e}, {fmt_exc}")
              # FIXME: is an ACK needed here?
              self.consumer.ack(msg)