|
- #! /usr/bin/env python
- # -*- coding: utf-8 -*-
- # vim:fenc=utf-8
- import time
- import logging
- import werkzeug.exceptions
- from flask import Flask, request, jsonify
- from argparse import ArgumentParser
- from sqlalchemy.orm import sessionmaker
- from pqai_agent import configs
- from pqai_agent import logging_service, chat_service, prompt_templates
- from pqai_agent.agents.message_reply_agent import MessageReplyAgent
- from pqai_agent.data_models.agent_configuration import AgentConfiguration
- from pqai_agent.data_models.service_module import ServiceModule
- from pqai_agent.history_dialogue_service import HistoryDialogueService
- from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
- from pqai_agent.utils.db_utils import create_sql_engine
- from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
- from pqai_agent_server.const import AgentApiConst
- from pqai_agent_server.models import MySQLSessionManager
- from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
- from pqai_agent_server.utils import (
- run_extractor_prompt,
- run_chat_prompt,
- run_response_type_prompt,
- )
# Flask application object. The service managers used by the handlers below
# (user_manager, session_manager, session_maker, ...) are attached to it as
# attributes by the ``__main__`` block at the bottom of this file.
app = Flask('agent_api_server')
# Logger object exposed by the project's logging_service module.
logger = logging_service.logger
# Holder of API constants; the handlers read their default paging/status
# values (DEFAULT_PAGE_ID, DEFAULT_PAGE_SIZE, ...) from it.
const = AgentApiConst()
@app.route('/api/listStaffs', methods=['GET'])
def list_staffs():
    """Return whatever ``user_relation_manager.list_staffs()`` yields, wrapped in a 200 response."""
    return wrap_response(200, data=app.user_relation_manager.list_staffs())
@app.route('/api/getStaffProfile', methods=['GET'])
def get_staff_profile():
    """Look up a staff profile by the required ``staff_id`` query parameter.

    Responds with code 200 and the profile, or code 404 when the lookup
    returns a falsy value.
    """
    staff_profile = app.user_manager.get_staff_profile(request.args['staff_id'])
    if staff_profile:
        return wrap_response(200, data=staff_profile)
    return wrap_response(404, msg='staff not found')
@app.route('/api/getUserProfile', methods=['GET'])
def get_user_profile():
    """Look up a user profile by the required ``user_id`` query parameter.

    The JSON body carries ``code`` 200 with the profile under ``data``, or
    ``code`` 404 when the lookup returns a falsy value.
    """
    found = app.user_manager.get_user_profile(request.args['user_id'])
    if found:
        body = {'code': 200, 'msg': 'success', 'data': found}
    else:
        body = {'code': 404, 'msg': 'user not found'}
    return jsonify(body)
@app.route('/api/listUsers', methods=['GET'])
def list_users():
    """Search users by ``user_name`` and/or ``user_union_id``.

    At least one of the two query parameters must be non-empty; otherwise
    the JSON body carries ``code`` 400.
    """
    name_filter = request.args.get('user_name', None)
    union_id_filter = request.args.get('user_union_id', None)
    if name_filter or union_id_filter:
        matches = app.user_manager.list_users(
            user_name=name_filter, user_union_id=union_id_filter
        )
        return jsonify({'code': 200, 'data': matches})
    return jsonify({'code': 400, 'msg': 'user_name or user_union_id is required'})
@app.route('/api/getDialogueHistory', methods=['GET'])
def get_dialogue_history():
    """Return the dialogue history between a staff member and a user.

    Query parameters:
        staff_id: required.
        user_id: required.
        recent_minutes: optional integer, defaults to 1440.

    Returns:
        A JSON body with ``code`` 200 and the history under ``data``, or
        ``code`` 400 when ``recent_minutes`` is not an integer.
    """
    staff_id = request.args['staff_id']
    user_id = request.args['user_id']
    try:
        recent_minutes = int(request.args.get('recent_minutes', 1440))
    except (TypeError, ValueError):
        # A malformed value is a client error; without this guard the
        # ValueError would surface as an unhandled 500.
        return jsonify({'code': 400, 'msg': 'recent_minutes must be an integer'})
    dialogue_history = app.history_dialogue_service.get_dialogue_history(
        staff_id, user_id, recent_minutes
    )
    return jsonify({'code': 200, 'data': dialogue_history})
@app.route('/api/listModels', methods=['GET'])
def list_models():
    """Return the fixed list of selectable chat models.

    Each entry has ``model_type``, ``model_name`` and ``display_name``.
    """
    # (model identifier, label) pairs, in the order they are returned.
    catalogue = (
        (chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3, 'DeepSeek V3 on 火山'),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K, '豆包Pro 32K'),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5, '豆包Pro 1.5'),
        (chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH, 'DeepSeek V3联网 on 火山'),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO, '豆包1.5视觉理解Pro'),
    )
    models = [
        {
            'model_type': 'openai_compatible',
            'model_name': model_name,
            'display_name': label,
        }
        for model_name, label in catalogue
    ]
    return wrap_response(200, data=models)
@app.route('/api/listScenes', methods=['GET'])
def list_scenes():
    """Return the fixed list of prompt scenes with their display names."""
    # (scene key, label) pairs, in the order they are returned.
    scene_labels = (
        ('greeting', '问候'),
        ('chitchat', '闲聊'),
        ('profile_extractor', '画像提取'),
        ('response_type_detector', '回复模态判断'),
        ('custom_debugging', '自定义调试场景'),
    )
    scenes = [
        {'scene': key, 'display_name': label}
        for key, label in scene_labels
    ]
    return wrap_response(200, data=scenes)
@app.route('/api/getBasePrompt', methods=['GET'])
def get_base_prompt():
    """Return the default model name and base prompt for the required ``scene`` parameter.

    Responds with code 404 when the scene is not one of the known keys.
    """
    scene = request.args['scene']
    # scene -> (default model name, base prompt content)
    defaults = {
        'greeting': (
            chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
            prompt_templates.GENERAL_GREETING_PROMPT,
        ),
        'chitchat': (
            chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
            prompt_templates.CHITCHAT_PROMPT_COZE,
        ),
        'profile_extractor': (
            chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
            prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
        ),
        'response_type_detector': (
            chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
            prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
        ),
        'custom_debugging': (
            chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
            '',
        ),
    }
    if scene not in defaults:
        return wrap_response(404, msg='scene not found')
    model_name, content = defaults[scene]
    return wrap_response(200, data={'model_name': model_name, 'content': content})
@app.route('/api/runPrompt', methods=['POST'])
def run_prompt():
    """Run a prompt for the scene named in the JSON body and return its output.

    ``profile_extractor`` returns the extractor result as-is; every other
    scene returns the first choice's message content. Any exception is
    logged and reported with code 500.
    """
    try:
        payload = request.json
        logger.debug(payload)
        scene = payload['scene']
        if scene == 'profile_extractor':
            result = run_extractor_prompt(payload)
        else:
            # Both remaining paths return a chat-completion style object.
            if scene == 'response_type_detector':
                completion = run_response_type_prompt(payload)
            else:
                completion = run_chat_prompt(payload)
            result = completion.choices[0].message.content
        return wrap_response(200, data=result)
    except Exception as e:
        logger.error(e)
        return wrap_response(500, msg='Error: {}'.format(e))
@app.route('/api/formatForPrompt', methods=['POST'])
def format_data_for_prompt():
    """Format ``content`` from the JSON body according to ``format_type``.

    Supported types are ``staff_profile`` and ``user_profile`` (content must
    be a dict) and ``dialogue`` (content must be a list). A wrong content
    type or unknown ``format_type`` yields code 400; any exception is logged
    and reported with code 500.
    """
    try:
        payload = request.json
        content = payload['content']
        format_type = payload['format_type']

        if format_type == 'dialogue':
            if not isinstance(content, list):
                return wrap_response(400, msg='dialogue should be a list')
            # Imported lazily, only when a dialogue is actually formatted.
            from pqai_agent_server.utils.prompt_util import format_dialogue_history
            formatted = format_dialogue_history(content)
        elif format_type == 'user_profile':
            if not isinstance(content, dict):
                return wrap_response(400, msg='user_profile should be a dict')
            formatted = format_user_profile(content)
        elif format_type == 'staff_profile':
            if not isinstance(content, dict):
                return wrap_response(400, msg='staff_profile should be a dict')
            formatted = format_agent_profile(content)
        else:
            return wrap_response(400, msg='Invalid format_type')

        return wrap_response(200, data=formatted)
    except Exception as e:
        logger.error(e)
        return wrap_response(500, msg='Error: {}'.format(e))
@app.route("/api/healthCheck", methods=["GET"])
def health_check():
    """Liveness probe: always answers with code 200 and the message "OK"."""
    ok_response = wrap_response(200, msg="OK")
    return ok_response
@app.route("/api/getStaffSessionSummary", methods=["GET"])
def get_staff_session_summary():
    """Return a paged session summary from the session manager.

    Query parameters:
        staff_id: optional; passed through unchanged.
        status: optional integer, defaults to ``const.DEFAULT_STAFF_STATUS``.
        page_id: optional integer, defaults to ``const.DEFAULT_PAGE_ID``.
        page_size: optional integer, defaults to ``const.DEFAULT_PAGE_SIZE``.

    Returns:
        Code 200 with the summary, code 400 when a numeric parameter is
        malformed, or code 404 when the session manager returns nothing.
    """
    staff_id = request.args.get("staff_id")
    status = request.args.get("status", const.DEFAULT_STAFF_STATUS)
    page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
    page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
    try:
        page_id = int(page_id)
        page_size = int(page_size)
        status = int(status)
    except (TypeError, ValueError) as e:
        # Malformed input is a client error (400), not a missing resource;
        # this matches list_users and the BadRequest error handler.
        return wrap_response(400, msg="Invalid parameter: {}".format(e))

    staff_session_summary = app.session_manager.get_staff_sessions_summary(
        staff_id, page_id, page_size, status
    )
    if not staff_session_summary:
        return wrap_response(404, msg="staff not found")
    return wrap_response(200, data=staff_session_summary)
@app.route("/api/getStaffSessionList", methods=["GET"])
def get_staff_session_list():
    """Return one page of sessions for a staff member.

    Query parameters:
        staff_id: required.
        page_id: optional integer, defaults to ``const.DEFAULT_PAGE_ID``.
        page_size: optional integer, defaults to ``const.DEFAULT_PAGE_SIZE``.

    Returns:
        Code 200 with the session list, code 400 when ``staff_id`` is
        missing or a numeric parameter is malformed, or code 404 when the
        session manager returns nothing.
    """
    staff_id = request.args.get("staff_id")
    if not staff_id:
        # A missing required parameter is a client error, not "not found".
        return wrap_response(400, msg="staff_id is required")
    page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
    page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
    try:
        page_id = int(page_id)
        page_size = int(page_size)
    except (TypeError, ValueError) as e:
        return wrap_response(400, msg="Invalid parameter: {}".format(e))

    staff_session_list = app.session_manager.get_staff_session_list(
        staff_id, page_id, page_size
    )
    if not staff_session_list:
        return wrap_response(404, msg="staff not found")
    return wrap_response(200, data=staff_session_list)
@app.route("/api/getStaffList", methods=["GET"])
def get_staff_list():
    """Return one page of staff records from the user manager.

    Query parameters:
        page_id: optional integer, defaults to ``const.DEFAULT_PAGE_ID``.
        page_size: optional integer, defaults to ``const.DEFAULT_PAGE_SIZE``.

    Returns:
        Code 200 with the staff list, code 400 when a numeric parameter is
        malformed, or code 404 when the user manager returns nothing.
    """
    page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
    page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
    try:
        page_id = int(page_id)
        page_size = int(page_size)
    except (TypeError, ValueError) as e:
        # Malformed input is a client error (400), not a missing resource.
        return wrap_response(400, msg="Invalid parameter: {}".format(e))

    staff_list = app.user_manager.get_staff_list(page_id, page_size)
    if not staff_list:
        return wrap_response(404, msg="staff not found")
    return wrap_response(200, data=staff_list)
@app.route("/api/getConversationList", methods=["GET"])
def get_conversation_list():
    """Return the private-chat conversation list between a staff member and a user.

    Query parameters:
        staff_id: required.
        user_id: required.
        page: optional; forwarded to the session manager unchanged.

    Returns:
        Code 200 with the session manager's result, or code 400 when
        ``staff_id`` or ``user_id`` is missing.
    """
    staff_id = request.args.get("staff_id")
    user_id = request.args.get("user_id")
    if not staff_id or not user_id:
        # A missing required parameter is a client error, not "not found".
        return wrap_response(400, msg="staff_id and user_id are required")
    page = request.args.get("page")
    response = app.session_manager.get_conversation_list(
        staff_id, user_id, page, const.DEFAULT_CONVERSATION_SIZE
    )
    return wrap_response(200, data=response)
@app.route("/api/sendMessage", methods=["POST"])
def send_message():
    """Placeholder endpoint: always answers with code 200 and a fixed message."""
    placeholder = wrap_response(200, msg="暂不实现功能")
    return placeholder
@app.route("/api/quitHumanInterventionStatus", methods=["POST"])
def quit_human_interventions_status():
    """Leave the human-intervention state for a staff/user pair.

    JSON body:
        staff_id: required.
        user_id: required.

    Returns:
        Code 200 with the result of ``quit_human_intervention_status``, or
        code 400 when the body is not a JSON object or either id is missing.
    """
    req_data = request.json
    if not isinstance(req_data, dict):
        return wrap_response(400, msg="request body must be a JSON object")
    # Use .get so an absent key reaches the validation below instead of
    # raising KeyError, which would surface as an unhandled 500.
    staff_id = req_data.get("staff_id")
    user_id = req_data.get("user_id")
    if not user_id or not staff_id:
        return wrap_response(400, msg="user_id and staff_id are required")
    response = quit_human_intervention_status(user_id, staff_id)
    return wrap_response(200, data=response)
# Agent management endpoints
@app.route("/api/getNativeAgentList", methods=["GET"])
def get_native_agent_list():
    """Return one page of non-deleted agent configurations.

    Query parameters:
        page: optional positive integer, defaults to 1.
        page_size: optional positive integer, defaults to 50.
        create_user: optional exact-match filter.
        update_user: optional exact-match filter.

    Returns:
        Code 200 with a list of agent summaries, or code 400 when ``page``
        or ``page_size`` is not a positive integer.
    """
    create_user = request.args.get('create_user', None)
    update_user = request.args.get('update_user', None)
    try:
        page = int(request.args.get('page', 1))
        page_size = int(request.args.get('page_size', 50))
    except (TypeError, ValueError) as e:
        return wrap_response(400, msg='Invalid parameter: {}'.format(e))
    if page < 1 or page_size < 1:
        # page < 1 would otherwise produce a negative OFFSET.
        return wrap_response(400, msg='page and page_size must be positive integers')
    offset = (page - 1) * page_size

    time_format = '%Y-%m-%d %H:%M:%S'
    with app.session_maker() as session:
        query = session.query(AgentConfiguration) \
            .filter(AgentConfiguration.is_delete == 0)
        if create_user:
            query = query.filter(AgentConfiguration.create_user == create_user)
        if update_user:
            query = query.filter(AgentConfiguration.update_user == update_user)
        # Order explicitly so consecutive pages are stable and non-overlapping.
        query = query.order_by(AgentConfiguration.id).offset(offset).limit(page_size)
        ret_data = [
            {
                'id': agent.id,
                'name': agent.name,
                'display_name': agent.display_name,
                'type': agent.type,
                'execution_model': agent.execution_model,
                'create_time': agent.create_time.strftime(time_format),
                'update_time': agent.update_time.strftime(time_format)
            }
            for agent in query.all()
        ]
    return wrap_response(200, data=ret_data)
@app.route("/api/getNativeAgentConfiguration", methods=["GET"])
def get_native_agent_configuration():
    """Return the full configuration of one agent.

    Query parameters:
        agent_id: required.

    Returns:
        Code 200 with the configuration, code 400 when ``agent_id`` is
        missing, or code 404 when no agent has that id.
    """
    agent_id = request.args.get('agent_id')
    if not agent_id:
        # A missing required parameter is a client error, not "not found".
        return wrap_response(400, msg='agent_id is required')

    time_format = '%Y-%m-%d %H:%M:%S'
    with app.session_maker() as session:
        agent = session.query(AgentConfiguration) \
            .filter(AgentConfiguration.id == agent_id).first()
        if not agent:
            return wrap_response(404, msg='Agent not found')
        # Build the payload while the session is still open.
        data = {
            'id': agent.id,
            'name': agent.name,
            'display_name': agent.display_name,
            'type': agent.type,
            'execution_model': agent.execution_model,
            'system_prompt': agent.system_prompt,
            'task_prompt': agent.task_prompt,
            'tools': agent.tools,
            'sub_agents': agent.sub_agents,
            'extra_params': agent.extra_params,
            'create_time': agent.create_time.strftime(time_format),
            'update_time': agent.update_time.strftime(time_format)
        }
    return wrap_response(200, data=data)
@app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
def save_native_agent_configuration():
    """Create or update an agent configuration.

    JSON body:
        agent_id: optional; when truthy the matching agent is updated,
            otherwise a new agent is created.
        name: required.
        display_name, type, execution_model, system_prompt, task_prompt,
        tools, sub_agents, extra_params: optional; missing keys fall back
        to None / 0 / [] / {} as before.

    Returns:
        Code 200 with the saved agent's id, code 400 when the body is not a
        JSON object, ``name`` is missing or ``agent_id`` is not an integer,
        or code 404 when the agent to update does not exist.
    """
    req_data = request.json
    if not isinstance(req_data, dict):
        return wrap_response(400, msg='request body must be a JSON object')

    fields = {
        'name': req_data.get('name'),
        'display_name': req_data.get('display_name', None),
        'type': req_data.get('type', 0),
        'execution_model': req_data.get('execution_model', None),
        'system_prompt': req_data.get('system_prompt', None),
        'task_prompt': req_data.get('task_prompt', None),
        'tools': req_data.get('tools', []),
        'sub_agents': req_data.get('sub_agents', []),
        'extra_params': req_data.get('extra_params', {}),
    }
    if not fields['name']:
        return wrap_response(400, msg='name is required')

    agent_id = req_data.get('agent_id', None)
    # Decide create-vs-update on the raw value, so "0" still means "update".
    is_update = bool(agent_id)
    if is_update:
        try:
            agent_id = int(agent_id)
        except (TypeError, ValueError):
            # Without this guard a malformed id surfaces as an unhandled 500.
            return wrap_response(400, msg='agent_id must be an integer')

    with app.session_maker() as session:
        if is_update:
            agent = session.query(AgentConfiguration) \
                .filter(AgentConfiguration.id == agent_id).first()
            if not agent:
                return wrap_response(404, msg='Agent not found')
            for attr, value in fields.items():
                setattr(agent, attr, value)
        else:
            agent = AgentConfiguration(**fields)
            session.add(agent)
        session.commit()
        # Read the id while the session is still open.
        saved_id = agent.id
    return wrap_response(200, msg='Agent configuration saved successfully', data={'id': saved_id})
@app.route("/api/getModuleList", methods=["GET"])
def get_module_list():
    """Return a summary of every service module whose ``is_delete`` flag is 0."""
    with app.session_maker() as session:
        modules = session.query(ServiceModule).filter(ServiceModule.is_delete == 0).all()
    ret_data = []
    for module in modules:
        ret_data.append({
            'id': module.id,
            'name': module.name,
            'display_name': module.display_name,
            'default_agent_type': module.default_agent_type,
            'default_agent_id': module.default_agent_id,
            'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
            'update_time': module.update_time.strftime('%Y-%m-%d %H:%M:%S')
        })
    return wrap_response(200, data=ret_data)
@app.route("/api/getModuleConfiguration", methods=["GET"])
def get_module_configuration():
    """Return the configuration of one service module.

    Query parameters:
        module_id: required.

    Returns:
        Code 200 with the configuration, code 400 when ``module_id`` is
        missing, or code 404 when no module has that id.
    """
    module_id = request.args.get('module_id')
    if not module_id:
        # A missing required parameter is a client error, not "not found".
        return wrap_response(400, msg='module_id is required')

    time_format = '%Y-%m-%d %H:%M:%S'
    with app.session_maker() as session:
        module = session.query(ServiceModule) \
            .filter(ServiceModule.id == module_id).first()
        if not module:
            return wrap_response(404, msg='Module not found')
        data = {
            'id': module.id,
            'name': module.name,
            'display_name': module.display_name,
            'default_agent_type': module.default_agent_type,
            'default_agent_id': module.default_agent_id,
            'create_time': module.create_time.strftime(time_format),
            # NOTE(review): get_module_list reads ``update_time``; the former
            # ``updated_time`` here looked like a typo. Confirm against the
            # ServiceModule model.
            'update_time': module.update_time.strftime(time_format)
        }
    return wrap_response(200, data=data)
@app.route("/api/saveModuleConfiguration", methods=["POST"])
def save_module_configuration():
    """Create or update a service module configuration.

    JSON body:
        module_id: optional; when truthy the matching module is updated,
            otherwise a new module is created.
        name: required.
        display_name, default_agent_type, default_agent_id: optional;
        missing keys fall back to None / 0 / None as before.

    Returns:
        Code 200 with the saved module's id, code 400 when the body is not
        a JSON object, ``name`` is missing or ``module_id`` is not an
        integer, or code 404 when the module to update does not exist.
    """
    req_data = request.json
    if not isinstance(req_data, dict):
        return wrap_response(400, msg='request body must be a JSON object')

    fields = {
        'name': req_data.get('name'),
        'display_name': req_data.get('display_name', None),
        'default_agent_type': req_data.get('default_agent_type', 0),
        'default_agent_id': req_data.get('default_agent_id', None),
    }
    if not fields['name']:
        return wrap_response(400, msg='name is required')

    module_id = req_data.get('module_id', None)
    # Decide create-vs-update on the raw value, so "0" still means "update".
    is_update = bool(module_id)
    if is_update:
        try:
            module_id = int(module_id)
        except (TypeError, ValueError):
            # Without this guard a malformed id surfaces as an unhandled 500.
            return wrap_response(400, msg='module_id must be an integer')

    with app.session_maker() as session:
        if is_update:
            module = session.query(ServiceModule) \
                .filter(ServiceModule.id == module_id).first()
            if not module:
                return wrap_response(404, msg='Module not found')
            for attr, value in fields.items():
                setattr(module, attr, value)
        else:
            module = ServiceModule(**fields)
            session.add(module)
        session.commit()
        # Read the id while the session is still open.
        saved_id = module.id
    return wrap_response(200, msg='Module configuration saved successfully', data={'id': saved_id})
@app.errorhandler(werkzeug.exceptions.BadRequest)
def handle_bad_request(e):
    """Log a werkzeug BadRequest and report its description with code 400."""
    logger.error(e)
    message = 'Bad Request: {}'.format(e.description)
    return wrap_response(400, msg=message)
if __name__ == '__main__':
    # ---- command-line options ----
    parser = ArgumentParser()
    for flag, options in (
        ('--prod', {'action': 'store_true'}),
        ('--host', {'default': '127.0.0.1'}),
        ('--port', {'type': int, 'default': 8083}),
        ('--log-level', {'default': 'INFO'}),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    # ---- configuration and logging ----
    config = configs.get()
    logging_service.setup_root_logger(
        level=logging.getLevelName(args.log_level),
        logfile_name='agent_api_server.log',
    )

    # ---- storage sections used below ----
    storage = config['storage']
    user_db_config = storage['user']
    staff_db_config = storage['staff']
    agent_state_db_config = storage['agent_state']
    chat_history_db_config = storage['chat_history']

    # ---- managers attached to the app for the request handlers ----
    app.user_manager = MySQLUserManager(
        user_db_config['mysql'],
        user_db_config['table'],
        staff_db_config['table'],
    )

    app.session_manager = MySQLSessionManager(
        db_config=user_db_config['mysql'],
        staff_table=staff_db_config['table'],
        user_table=user_db_config['table'],
        agent_state_table=agent_state_db_config['table'],
        chat_history_table=chat_history_db_config['table'],
    )

    # SQLAlchemy session factory bound to the agent_state database.
    app.session_maker = sessionmaker(
        bind=create_sql_engine(agent_state_db_config['mysql'])
    )

    wecom_db_config = storage['user_relation']
    wecom_tables = wecom_db_config['table']
    app.user_relation_manager = MySQLUserRelationManager(
        user_db_config['mysql'],
        wecom_db_config['mysql'],
        staff_db_config['table'],
        user_db_config['table'],
        wecom_tables['staff'],
        wecom_tables['relation'],
        wecom_tables['user'],
    )

    app.history_dialogue_service = HistoryDialogueService(
        storage['history_dialogue']['api_base_url']
    )

    # Debug mode is on unless --prod is given.
    app.run(debug=not args.prod, host=args.host, port=args.port)
|