agent_server.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. import logging
  2. import sys
  3. import time
  4. from pathlib import Path
  5. # 获取当前文件的父目录的父目录(项目根目录)
  6. BASE_DIR = Path(__file__).resolve().parent.parent
  7. sys.path.insert(0, str(BASE_DIR)) # 将项目根目录添加到模块搜索路径
  8. from pqai_agent import configs
  9. from pqai_agent.agent_service import AgentService
  10. from pqai_agent.chat_service import ChatServiceType
  11. from pqai_agent.logging import logger, setup_root_logger
  12. from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
  13. from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend, MemoryQueueBackend
  14. from pqai_agent.push_service import PushTaskWorkerPool, PushScanThread
  15. from pqai_agent.user_manager import LocalUserManager, LocalUserRelationManager, MySQLUserManager, \
  16. MySQLUserRelationManager
  17. if __name__ == "__main__":
  18. config = configs.get()
  19. setup_root_logger()
  20. logger.warning("current env: {}".format(configs.get_env()))
  21. scheduler_logger = logging.getLogger('apscheduler')
  22. scheduler_logger.setLevel(logging.WARNING)
  23. use_aliyun_mq = config['debug_flags']['use_aliyun_mq']
  24. # 初始化不同队列的后端
  25. if use_aliyun_mq:
  26. # 实际只创建producer,consumer在工作线程中创建
  27. receive_queue = AliyunRocketMQQueueBackend(
  28. config['mq']['endpoints'],
  29. config['mq']['instance_id'],
  30. config['mq']['receive_topic'],
  31. has_consumer=False, has_producer=True,
  32. group_id=config['mq']['receive_group'],
  33. topic_type='FIFO'
  34. )
  35. if configs.get_env() == 'prod':
  36. send_queue = AliyunRocketMQQueueBackend(
  37. config['mq']['endpoints'],
  38. config['mq']['instance_id'],
  39. config['mq']['send_topic'],
  40. has_consumer=False, has_producer=True,
  41. topic_type='FIFO'
  42. )
  43. else:
  44. send_queue = MemoryQueueBackend()
  45. else:
  46. receive_queue = MemoryQueueBackend()
  47. send_queue = MemoryQueueBackend()
  48. human_queue = MemoryQueueBackend()
  49. # 初始化用户管理服务
  50. # FIXME(zhoutian): 如果不使用MySQL,此数据库配置非必须
  51. agent_db_config = config['database']['ai_agent']
  52. growth_db_config = config['database']['growth']
  53. user_db_config = config['storage']['user']
  54. staff_db_config = config['storage']['staff']
  55. wecom_db_config = config['storage']['user_relation']
  56. if config['debug_flags'].get('use_local_user_storage', False):
  57. user_manager = LocalUserManager()
  58. user_relation_manager = LocalUserRelationManager()
  59. else:
  60. user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
  61. user_relation_manager = MySQLUserRelationManager(
  62. agent_db_config, growth_db_config,
  63. staff_db_config['table'],
  64. user_db_config['table'],
  65. wecom_db_config['table']['staff'],
  66. wecom_db_config['table']['relation'],
  67. wecom_db_config['table']['user']
  68. )
  69. # 创建Agent服务
  70. service = AgentService(
  71. receive_backend=receive_queue,
  72. send_backend=send_queue,
  73. human_backend=human_queue,
  74. user_manager=user_manager,
  75. user_relation_manager=user_relation_manager,
  76. chat_service_type=ChatServiceType.OPENAI_COMPATIBLE
  77. )
  78. if not config['debug_flags'].get('console_input', False):
  79. service.start(blocking=True)
  80. sys.exit(0)
  81. else:
  82. service.start()
  83. message_id = 0
  84. while service.running:
  85. print("Input next message: ")
  86. text = sys.stdin.readline().strip()
  87. if not text:
  88. continue
  89. message_id += 1
  90. sender = '7881301903997433'
  91. receiver = '1688854974625870'
  92. if text in (MessageType.AGGREGATION_TRIGGER.name,
  93. MessageType.HUMAN_INTERVENTION_END.name):
  94. message = MqMessage.build(
  95. MessageType.__members__.get(text),
  96. MessageChannel.CORP_WECHAT,
  97. sender, receiver, None, int(time.time() * 1000))
  98. elif text == 'S_PUSH':
  99. service._check_initiative_conversations()
  100. continue
  101. else:
  102. message = MqMessage.build(MessageType.TEXT, MessageChannel.CORP_WECHAT,
  103. sender, receiver, text, int(time.time() * 1000)
  104. )
  105. message.msgId = message_id
  106. receive_queue.produce(message)
  107. time.sleep(0.1)