# agent_server.py
  1. import logging
  2. import sys
  3. import time
  4. from pqai_agent import configs, logging_service
  5. from pqai_agent.agent_service import AgentService
  6. from pqai_agent.chat_service import ChatServiceType
  7. from pqai_agent.logging_service import logger
  8. from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
  9. from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend, MemoryQueueBackend
  10. from pqai_agent.push_service import PushTaskWorkerPool, PushScanThread
  11. from pqai_agent.user_manager import LocalUserManager, LocalUserRelationManager, MySQLUserManager, \
  12. MySQLUserRelationManager
  13. if __name__ == "__main__":
  14. config = configs.get()
  15. logging_service.setup_root_logger()
  16. logger.warning("current env: {}".format(configs.get_env()))
  17. scheduler_logger = logging.getLogger('apscheduler')
  18. scheduler_logger.setLevel(logging.WARNING)
  19. use_aliyun_mq = config['debug_flags']['use_aliyun_mq']
  20. # 初始化不同队列的后端
  21. if use_aliyun_mq:
  22. # 实际只创建producer,consumer在工作线程中创建
  23. receive_queue = AliyunRocketMQQueueBackend(
  24. config['mq']['endpoints'],
  25. config['mq']['instance_id'],
  26. config['mq']['receive_topic'],
  27. has_consumer=False, has_producer=True,
  28. group_id=config['mq']['receive_group'],
  29. topic_type='FIFO'
  30. )
  31. if configs.get_env() == 'prod':
  32. send_queue = AliyunRocketMQQueueBackend(
  33. config['mq']['endpoints'],
  34. config['mq']['instance_id'],
  35. config['mq']['send_topic'],
  36. has_consumer=False, has_producer=True,
  37. topic_type='FIFO'
  38. )
  39. else:
  40. send_queue = MemoryQueueBackend()
  41. else:
  42. receive_queue = MemoryQueueBackend()
  43. send_queue = MemoryQueueBackend()
  44. human_queue = MemoryQueueBackend()
  45. # 初始化用户管理服务
  46. # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
  47. agent_db_config = config['database']['ai_agent']
  48. growth_db_config = config['database']['growth']
  49. user_db_config = config['storage']['user']
  50. staff_db_config = config['storage']['staff']
  51. wecom_db_config = config['storage']['user_relation']
  52. if config['debug_flags'].get('use_local_user_storage', False):
  53. user_manager = LocalUserManager()
  54. user_relation_manager = LocalUserRelationManager()
  55. else:
  56. user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
  57. user_relation_manager = MySQLUserRelationManager(
  58. agent_db_config, growth_db_config,
  59. staff_db_config['table'],
  60. user_db_config['table'],
  61. wecom_db_config['table']['staff'],
  62. wecom_db_config['table']['relation'],
  63. wecom_db_config['table']['user']
  64. )
  65. # 创建Agent服务
  66. service = AgentService(
  67. receive_backend=receive_queue,
  68. send_backend=send_queue,
  69. human_backend=human_queue,
  70. user_manager=user_manager,
  71. user_relation_manager=user_relation_manager,
  72. chat_service_type=ChatServiceType.COZE_CHAT
  73. )
  74. if not config['debug_flags'].get('console_input', False):
  75. service.start(blocking=True)
  76. sys.exit(0)
  77. else:
  78. service.start()
  79. message_id = 0
  80. while service.running:
  81. print("Input next message: ")
  82. text = sys.stdin.readline().strip()
  83. if not text:
  84. continue
  85. message_id += 1
  86. sender = '7881301903997433'
  87. receiver = '1688855931724582'
  88. if text in (MessageType.AGGREGATION_TRIGGER.name,
  89. MessageType.HUMAN_INTERVENTION_END.name):
  90. message = MqMessage.build(
  91. MessageType.__members__.get(text),
  92. MessageChannel.CORP_WECHAT,
  93. sender, receiver, None, int(time.time() * 1000))
  94. else:
  95. message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
  96. sender, receiver, text, int(time.time() * 1000)
  97. )
  98. message.msgId = message_id
  99. receive_queue.produce(message)
  100. time.sleep(0.1)