#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
"""HTTP API server for staff/user data, dialogue history and prompt debugging.

All routes are registered on the module-level Flask ``app``. The data-access
objects the handlers use (``app.user_manager``, ``app.user_relation_manager``,
``app.history_dialogue_service``) are attached as attributes in the
``__main__`` block at the bottom of this module; anything else that serves
``app`` must attach them itself first.
"""
import logging
from argparse import ArgumentParser

import werkzeug.exceptions
from flask import Flask, request, jsonify

from pqai_agent import configs
from pqai_agent import logging_service, chat_service, prompt_templates
from pqai_agent.history_dialogue_service import HistoryDialogueService
from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
from pqai_agent_server.const import AgentApiConst
from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
from pqai_agent_server.utils import (
    run_extractor_prompt,
    run_chat_prompt,
    run_response_type_prompt,
)

app = Flask("agent_api_server")
logger = logging_service.logger
const = AgentApiConst()


def _get_paging_args():
    """Read ``page_id`` and ``page_size`` from the query string as ints.

    Missing values fall back to ``const.DEFAULT_PAGE_ID`` and
    ``const.DEFAULT_PAGE_SIZE``.

    :return: ``(page_id, page_size)``, both converted with ``int()``.
    :raises ValueError: if a supplied value is not an integer string.
    :raises TypeError: if a value is of a type ``int()`` cannot convert.
    """
    page_id = int(request.args.get("page_id", const.DEFAULT_PAGE_ID))
    page_size = int(request.args.get("page_size", const.DEFAULT_PAGE_SIZE))
    return page_id, page_size


@app.route("/api/listStaffs", methods=["GET"])
def list_staffs():
    """Return every staff entry known to the user-relation manager."""
    staff_data = app.user_relation_manager.list_staffs()
    return wrap_response(200, data=staff_data)


@app.route("/api/getStaffProfile", methods=["GET"])
def get_staff_profile():
    """Return the profile for the required ``staff_id`` query parameter.

    A missing ``staff_id`` raises werkzeug's ``BadRequestKeyError``, which
    ``handle_bad_request`` turns into a 400 response.
    """
    staff_id = request.args["staff_id"]
    profile = app.user_manager.get_staff_profile(staff_id)
    if not profile:
        return wrap_response(404, msg="staff not found")
    else:
        return wrap_response(200, data=profile)


@app.route("/api/getUserProfile", methods=["GET"])
def get_user_profile():
    """Return the profile for the required ``user_id`` query parameter.

    Unlike most routes here, the response body is built by hand and returned
    through ``jsonify`` instead of ``wrap_response``.
    """
    user_id = request.args["user_id"]
    profile = app.user_manager.get_user_profile(user_id)
    if not profile:
        resp = {"code": 404, "msg": "user not found"}
    else:
        resp = {"code": 200, "msg": "success", "data": profile}
    return jsonify(resp)


@app.route("/api/listUsers", methods=["GET"])
def list_users():
    """Search users by ``user_name`` and/or ``user_union_id``.

    At least one of the two query parameters must be non-empty.
    """
    user_name = request.args.get("user_name", None)
    user_union_id = request.args.get("user_union_id", None)
    if not user_name and not user_union_id:
        resp = {"code": 400, "msg": "user_name or user_union_id is required"}
        return jsonify(resp)
    data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
    return jsonify({"code": 200, "data": data})


@app.route("/api/getDialogueHistory", methods=["GET"])
def get_dialogue_history():
    """Return the dialogue history between ``staff_id`` and ``user_id``.

    ``recent_minutes`` is optional (default 1440) and is forwarded to the
    history service as an int.

    :raises werkzeug.exceptions.BadRequest: if a required parameter is
        missing or ``recent_minutes`` is not an integer. ``handle_bad_request``
        answers with 400 instead of an unhandled 500.
    """
    staff_id = request.args["staff_id"]
    user_id = request.args["user_id"]
    try:
        recent_minutes = int(request.args.get("recent_minutes", 1440))
    except ValueError:
        raise werkzeug.exceptions.BadRequest(
            "recent_minutes must be an integer"
        ) from None
    dialogue_history = app.history_dialogue_service.get_dialogue_history(
        staff_id, user_id, recent_minutes
    )
    return jsonify({"code": 200, "data": dialogue_history})


@app.route("/api/listModels", methods=["GET"])
def list_models():
    """Return the fixed list of models selectable in the prompt debugger."""
    models = [
        {
            "model_type": "openai_compatible",
            "model_name": chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
            "display_name": "DeepSeek V3 on 火山",
        },
        {
            "model_type": "openai_compatible",
            "model_name": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
            "display_name": "豆包Pro 32K",
        },
        {
            "model_type": "openai_compatible",
            "model_name": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
            "display_name": "豆包Pro 1.5",
        },
        {
            "model_type": "openai_compatible",
            "model_name": chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
            "display_name": "DeepSeek V3联网 on 火山",
        },
        {
            "model_type": "openai_compatible",
            "model_name": chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
            "display_name": "豆包1.5视觉理解Pro",
        },
    ]
    return wrap_response(200, data=models)


@app.route("/api/listScenes", methods=["GET"])
def list_scenes():
    """Return the fixed list of debuggable scenes and their display names."""
    scenes = [
        {"scene": "greeting", "display_name": "问候"},
        {"scene": "chitchat", "display_name": "闲聊"},
        {"scene": "profile_extractor", "display_name": "画像提取"},
        {"scene": "response_type_detector", "display_name": "回复模态判断"},
        {"scene": "custom_debugging", "display_name": "自定义调试场景"},
    ]
    return wrap_response(200, data=scenes)


@app.route("/api/getBasePrompt", methods=["GET"])
def get_base_prompt():
    """Return the default prompt template and model name for ``scene``.

    ``prompt_map`` and ``model_map`` must contain the same scene keys; an
    unknown scene yields a 404 response.
    """
    scene = request.args["scene"]
    prompt_map = {
        "greeting": prompt_templates.GENERAL_GREETING_PROMPT,
        "chitchat": prompt_templates.CHITCHAT_PROMPT_COZE,
        "profile_extractor": prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
        "response_type_detector": prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
        "custom_debugging": "",
    }
    model_map = {
        "greeting": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
        "chitchat": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
        "profile_extractor": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
        "response_type_detector": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
        "custom_debugging": chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
    }
    if scene not in prompt_map:
        return wrap_response(404, msg="scene not found")
    data = {"model_name": model_map[scene], "content": prompt_map[scene]}
    return wrap_response(200, data=data)


@app.route("/api/runPrompt", methods=["POST"])
def run_prompt():
    """Run a debug prompt for the ``scene`` named in the JSON request body.

    ``profile_extractor`` returns the extractor's result as-is. Every other
    scene returns ``choices[0].message.content`` of the chat response.
    Any exception is logged with its traceback and reported as a 500
    response carrying the error text.
    """
    try:
        req_data = request.json
        logger.debug(req_data)
        scene = req_data["scene"]
        if scene == "profile_extractor":
            response = run_extractor_prompt(req_data)
            return wrap_response(200, data=response)
        elif scene == "response_type_detector":
            response = run_response_type_prompt(req_data)
            return wrap_response(200, data=response.choices[0].message.content)
        else:
            response = run_chat_prompt(req_data)
            return wrap_response(200, data=response.choices[0].message.content)
    except Exception as e:
        # Top-level boundary for debug runs: keep the traceback in the log,
        # since the client only receives the message text.
        logger.exception(e)
        return wrap_response(500, msg="Error: {}".format(e))


@app.route("/api/healthCheck", methods=["GET"])
def health_check():
    """Liveness probe; always answers 200 "OK"."""
    return wrap_response(200, msg="OK")


@app.route("/api/getStaffSessionSummary", methods=["GET"])
def get_staff_session_summary():
    """Return a paged session summary for ``staff_id`` filtered by ``status``.

    ``status``, ``page_id`` and ``page_size`` are optional and must be
    integers; defaults come from ``const``.
    """
    staff_id = request.args.get("staff_id")
    # check params
    try:
        page_id, page_size = _get_paging_args()
        status = int(request.args.get("status", const.DEFAULT_STAFF_STATUS))
    except (TypeError, ValueError) as e:
        # NOTE(review): a malformed parameter is a client error (400 would
        # fit better); 404 is kept because existing callers may rely on it.
        return wrap_response(404, msg="Invalid parameter: {}".format(e))
    staff_session_summary = app.user_manager.get_staff_sessions_summary_v1(
        staff_id, page_id, page_size, status
    )
    if not staff_session_summary:
        return wrap_response(404, msg="staff not found")
    else:
        return wrap_response(200, data=staff_session_summary)


@app.route("/api/getStaffSessionList", methods=["GET"])
def get_staff_session_list():
    """Return a paged session list for the required ``staff_id``.

    ``page_id`` and ``page_size`` are converted to int the same way as in
    ``get_staff_session_summary``, rather than passed through as raw strings.
    """
    staff_id = request.args.get("staff_id")
    if not staff_id:
        return wrap_response(404, msg="staff_id is required")
    try:
        page_id, page_size = _get_paging_args()
    except (TypeError, ValueError) as e:
        return wrap_response(404, msg="Invalid parameter: {}".format(e))
    staff_session_list = app.user_manager.get_staff_session_list_v1(
        staff_id, page_id, page_size
    )
    if not staff_session_list:
        return wrap_response(404, msg="staff not found")
    return wrap_response(200, data=staff_session_list)


@app.route("/api/getStaffList", methods=["GET"])
def get_staff_list():
    """Return a paged staff list; ``page_id``/``page_size`` must be integers."""
    try:
        page_id, page_size = _get_paging_args()
    except (TypeError, ValueError) as e:
        return wrap_response(404, msg="Invalid parameter: {}".format(e))
    staff_list = app.user_manager.get_staff_list(page_id, page_size)
    if not staff_list:
        return wrap_response(404, msg="staff not found")
    return wrap_response(200, data=staff_list)


@app.route("/api/getConversationList", methods=["GET"])
def get_conversation_list():
    """
    Fetch the private-chat conversation list between a staff member and a customer.

    ``staff_id`` and ``customer_id`` are required. ``page`` is forwarded
    unchanged (``None`` when absent); page size is
    ``const.DEFAULT_CONVERSATION_SIZE``.
    :return: the wrapped result of ``get_conversation_list_v1``.
    """
    staff_id = request.args.get("staff_id")
    customer_id = request.args.get("customer_id")
    if not staff_id or not customer_id:
        return wrap_response(404, msg="staff_id and customer_id are required")
    # NOTE(review): ``page`` is passed as a raw str/None; confirm that
    # get_conversation_list_v1 accepts that.
    page = request.args.get("page")
    response = app.user_manager.get_conversation_list_v1(
        staff_id, customer_id, page, const.DEFAULT_CONVERSATION_SIZE
    )
    return wrap_response(200, data=response)


@app.route("/api/quitHumanInterventionStatus", methods=["GET"])
def quit_human_interventions_status():
    """
    Exit the human-intervention status for a customer.

    :return: the wrapped result of ``quit_human_intervention_status``.
    """
    staff_id = request.args.get("staff_id")
    customer_id = request.args.get("customer_id")
    # Test environment: staff_id is forced to 1688854492669990.
    # TODO(review): this unconditionally overrides the request's staff_id
    # (and as an int, while query args are str); gate it behind a config flag
    # or remove it before production use.
    staff_id = 1688854492669990
    if not customer_id or not staff_id:
        return wrap_response(404, msg="customer_id and staff_id are required")
    response = quit_human_intervention_status(customer_id, staff_id)
    return wrap_response(200, data=response)


@app.route("/api/sendMessage", methods=["POST"])
def send_message():
    """Placeholder: sending messages is not implemented yet."""
    return wrap_response(200, msg="暂不实现功能")


@app.errorhandler(werkzeug.exceptions.BadRequest)
def handle_bad_request(e):
    """Turn any ``BadRequest`` (e.g. a missing required arg) into a 400 response."""
    logger.error(e)
    return wrap_response(400, msg="Bad Request: {}".format(e.description))


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--prod", action="store_true")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8083)
    parser.add_argument("--log-level", default="INFO")
    args = parser.parse_args()

    config = configs.get()
    # getLevelName() only maps upper-case names ("INFO") to numeric levels;
    # upper-casing lets "--log-level info" work as well.
    logging_level = logging.getLevelName(args.log_level.upper())
    logging_service.setup_root_logger(
        level=logging_level, logfile_name="agent_api_server.log"
    )

    user_db_config = config["storage"]["user"]
    staff_db_config = config["storage"]["staff"]
    user_manager = MySQLUserManager(
        user_db_config["mysql"], user_db_config["table"], staff_db_config["table"]
    )
    app.user_manager = user_manager

    wecom_db_config = config["storage"]["user_relation"]
    user_relation_manager = MySQLUserRelationManager(
        user_db_config["mysql"],
        wecom_db_config["mysql"],
        staff_db_config["table"],
        user_db_config["table"],
        wecom_db_config["table"]["staff"],
        wecom_db_config["table"]["relation"],
        wecom_db_config["table"]["user"],
    )
    app.user_relation_manager = user_relation_manager
    app.history_dialogue_service = HistoryDialogueService(
        config["storage"]["history_dialogue"]["api_base_url"]
    )
    # Without --prod, Flask debug mode (reloader + interactive debugger) is on;
    # only use that on a trusted host.
    app.run(debug=not args.prod, host=args.host, port=args.port)