123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- import logging
- import sys
- import time
- from pqai_agent import configs, logging_service
- from pqai_agent.agent_service import AgentService
- from pqai_agent.chat_service import ChatServiceType
- from pqai_agent.logging_service import logger
- from pqai_agent.mq_message import MessageType, MqMessage, MessageChannel
- from pqai_agent.message_queue_backend import AliyunRocketMQQueueBackend, MemoryQueueBackend
- from pqai_agent.push_service import PushTaskWorkerPool, PushScanThread
- from pqai_agent.user_manager import LocalUserManager, LocalUserRelationManager, MySQLUserManager, \
- MySQLUserRelationManager
if __name__ == "__main__":
    # Script entry point: load the configuration, wire the message-queue
    # backends and user managers into an AgentService, then either run the
    # service in blocking mode or drive it from messages typed on stdin
    # (when the 'console_input' debug flag is set).
    config = configs.get()
    logging_service.setup_root_logger()
    logger.warning("current env: {}".format(configs.get_env()))

    # Only let WARNING and above through from the 'apscheduler' logger.
    scheduler_logger = logging.getLogger('apscheduler')
    scheduler_logger.setLevel(logging.WARNING)

    use_aliyun_mq = config['debug_flags']['use_aliyun_mq']

    # Initialize the backends for the different queues.
    if use_aliyun_mq:
        # Only the producer is actually created here; the consumer is created
        # in the worker threads.
        receive_queue = AliyunRocketMQQueueBackend(
            config['mq']['endpoints'],
            config['mq']['instance_id'],
            config['mq']['receive_topic'],
            has_consumer=False, has_producer=True,
            group_id=config['mq']['receive_group'],
            topic_type='FIFO'
        )
        if configs.get_env() == 'prod':
            send_queue = AliyunRocketMQQueueBackend(
                config['mq']['endpoints'],
                config['mq']['instance_id'],
                config['mq']['send_topic'],
                has_consumer=False, has_producer=True,
                topic_type='FIFO'
            )
        else:
            # Outside 'prod' the outgoing messages go to an in-memory queue.
            send_queue = MemoryQueueBackend()
    else:
        receive_queue = MemoryQueueBackend()
        send_queue = MemoryQueueBackend()
    human_queue = MemoryQueueBackend()

    # Initialize the user management services.
    if config['debug_flags'].get('use_local_user_storage', False):
        # Local storage does not need any database configuration, so the
        # 'storage' section is deliberately not touched on this path.
        user_manager = LocalUserManager()
        user_relation_manager = LocalUserRelationManager()
    else:
        user_db_config = config['storage']['user']
        staff_db_config = config['storage']['staff']
        wecom_db_config = config['storage']['user_relation']
        user_manager = MySQLUserManager(
            user_db_config['mysql'], user_db_config['table'], staff_db_config['table']
        )
        user_relation_manager = MySQLUserRelationManager(
            user_db_config['mysql'], wecom_db_config['mysql'],
            staff_db_config['table'],
            user_db_config['table'],
            wecom_db_config['table']['staff'],
            wecom_db_config['table']['relation'],
            wecom_db_config['table']['user']
        )

    # Create the Agent service.
    service = AgentService(
        receive_backend=receive_queue,
        send_backend=send_queue,
        human_backend=human_queue,
        user_manager=user_manager,
        user_relation_manager=user_relation_manager,
        chat_service_type=ChatServiceType.COZE_CHAT
    )

    if not config['debug_flags'].get('console_input', False):
        service.start(blocking=True)
        sys.exit(0)
    else:
        service.start()

        # Fixed sender/receiver ids used for every message typed on the console.
        sender = '7881301903997433'
        receiver = '1688855931724582'
        # Names of message types that can be typed verbatim to send a
        # content-less control message instead of a TEXT message.
        control_type_names = (
            MessageType.AGGREGATION_TRIGGER.name,
            MessageType.HUMAN_INTERVENTION_END.name,
        )

        message_id = 0
        while service.running:
            print("Input next message: ")
            line = sys.stdin.readline()
            if not line:
                # readline() returns '' only at EOF (a blank line is '\n').
                # Without this check the loop would re-prompt forever without
                # ever blocking once stdin is closed.
                logger.warning("console input reached EOF, no more messages will be read")
                break
            text = line.strip()
            if not text:
                continue
            message_id += 1
            timestamp_ms = int(time.time() * 1000)
            if text in control_type_names:
                message = MqMessage.build(
                    MessageType[text],
                    MessageChannel.CORP_WECHAT,
                    sender, receiver, None, timestamp_ms)
            else:
                message = MqMessage.build(
                    MessageType.TEXT, MessageChannel.CORP_WECHAT,
                    sender, receiver, text, timestamp_ms)
            message.msgId = message_id
            receive_queue.produce(message)
            time.sleep(0.1)

        # After EOF keep the main thread alive for as long as the service
        # reports itself running, sleeping instead of spinning. This is a
        # no-op when the loop above ended because the service stopped.
        while service.running:
            time.sleep(1)
|