123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- #! /usr/bin/env python
- # -*- coding: utf-8 -*-
- # vim:fenc=utf-8
- import time
- import logging
- import werkzeug.exceptions
- from flask import Flask, request, jsonify
- from argparse import ArgumentParser
- from pqai_agent import configs
- from pqai_agent import logging_service, chat_service, prompt_templates
- from pqai_agent.history_dialogue_service import HistoryDialogueService
- from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
- from pqai_agent_server.utils import wrap_response
- from pqai_agent_server.utils import (
- run_extractor_prompt,
- run_chat_prompt,
- run_response_type_prompt,
- )
# Flask application object. The route handlers below read
# `app.user_manager`, `app.user_relation_manager` and
# `app.history_dialogue_service`, which are attached in the __main__ block.
# NOTE(review): those attributes are only set when this module runs as a
# script; importing `app` into another WSGI server leaves them unset.
app = Flask("agent_api_server")
# Shared logger object exposed by pqai_agent.logging_service.
logger = logging_service.logger
@app.route("/api/listStaffs", methods=["GET"])
def list_staffs():
    """Return the staff list produced by the user-relation manager."""
    relation_manager = app.user_relation_manager
    return wrap_response(200, data=relation_manager.list_staffs())
@app.route("/api/getStaffProfile", methods=["GET"])
def get_staff_profile():
    """Fetch one staff profile by the required ``staff_id`` query argument.

    Responds with code 404 when the user manager returns a falsy profile.
    A missing ``staff_id`` raises werkzeug's ``BadRequestKeyError`` (a
    ``BadRequest`` subclass), so the ``BadRequest`` error handler answers it.
    """
    staff_id = request.args["staff_id"]
    staff_profile = app.user_manager.get_staff_profile(staff_id)
    if staff_profile:
        return wrap_response(200, data=staff_profile)
    return wrap_response(404, msg="staff not found")
@app.route("/api/getUserProfile", methods=["GET"])
def get_user_profile():
    """Fetch one user profile by the required ``user_id`` query argument.

    Builds the JSON envelope directly with ``jsonify``: the payload has
    ``code`` and ``msg``, plus ``data`` on success.
    """
    user_id = request.args["user_id"]
    user_profile = app.user_manager.get_user_profile(user_id)
    if not user_profile:
        return jsonify({"code": 404, "msg": "user not found"})
    return jsonify({"code": 200, "msg": "success", "data": user_profile})
@app.route("/api/listUsers", methods=["GET"])
def list_users():
    """Search users by ``user_name`` and/or ``user_union_id``.

    At least one of the two query arguments must be non-empty; otherwise a
    code-400 payload is returned without querying the user manager.
    """
    name_filter = request.args.get("user_name")
    union_id_filter = request.args.get("user_union_id")
    if name_filter or union_id_filter:
        matched_users = app.user_manager.list_users(
            user_name=name_filter, user_union_id=union_id_filter
        )
        return jsonify({"code": 200, "data": matched_users})
    return jsonify({"code": 400, "msg": "user_name or user_union_id is required"})
@app.route("/api/getDialogueHistory", methods=["GET"])
def get_dialogue_history():
    """Return the dialogue history between a staff member and a user.

    Query args:
        staff_id: required.
        user_id: required.
        recent_minutes: optional integer look-back window, default 1440.

    A non-integer ``recent_minutes`` yields a code-400 payload instead of an
    unhandled ``ValueError`` (which Flask would turn into an HTTP 500).
    """
    staff_id = request.args["staff_id"]
    user_id = request.args["user_id"]
    try:
        recent_minutes = int(request.args.get("recent_minutes", 1440))
    except ValueError:
        # Same envelope style as the other jsonify-based endpoints.
        return jsonify({"code": 400, "msg": "recent_minutes must be an integer"})
    dialogue_history = app.history_dialogue_service.get_dialogue_history(
        staff_id, user_id, recent_minutes
    )
    return jsonify({"code": 200, "data": dialogue_history})
@app.route("/api/listModels", methods=["GET"])
def list_models():
    """List the selectable chat models with their display names.

    Every entry is served through the OpenAI-compatible client type.
    """
    catalog = (
        (chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3, "DeepSeek V3 on 火山"),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K, "豆包Pro 32K"),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5, "豆包Pro 1.5"),
        (chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH, "DeepSeek V3联网 on 火山"),
        (chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO, "豆包1.5视觉理解Pro"),
    )
    entries = [
        {
            "model_type": "openai_compatible",
            "model_name": model_id,
            "display_name": label,
        }
        for model_id, label in catalog
    ]
    return wrap_response(200, data=entries)
@app.route("/api/listScenes", methods=["GET"])
def list_scenes():
    """List the prompt-debugging scenes and their display names."""
    scene_labels = (
        ("greeting", "问候"),
        ("chitchat", "闲聊"),
        ("profile_extractor", "画像提取"),
        ("response_type_detector", "回复模态判断"),
        ("custom_debugging", "自定义调试场景"),
    )
    scene_entries = [
        {"scene": key, "display_name": label} for key, label in scene_labels
    ]
    return wrap_response(200, data=scene_entries)
@app.route("/api/getBasePrompt", methods=["GET"])
def get_base_prompt():
    """Return the default model and base prompt template for a scene.

    ``scene`` is a required query argument. Unknown scenes get a code-404
    payload; known ones return ``{"model_name": ..., "content": ...}``.
    """
    scene = request.args["scene"]
    # scene -> (default model name, base prompt template)
    scene_defaults = {
        "greeting": (
            chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
            prompt_templates.GENERAL_GREETING_PROMPT,
        ),
        "chitchat": (
            chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
            prompt_templates.CHITCHAT_PROMPT_COZE,
        ),
        "profile_extractor": (
            chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
            prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
        ),
        "response_type_detector": (
            chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
            prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
        ),
        "custom_debugging": (chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH, ""),
    }
    defaults = scene_defaults.get(scene)
    if defaults is None:
        return wrap_response(404, msg="scene not found")
    model_name, base_prompt = defaults
    return wrap_response(200, data={"model_name": model_name, "content": base_prompt})
@app.route("/api/runPrompt", methods=["POST"])
def run_prompt():
    """Execute a debugging prompt for the scene named in the JSON body.

    The body must be a JSON object with a ``scene`` key.
    - ``profile_extractor``: returns ``run_extractor_prompt``'s result as-is.
    - ``response_type_detector``: returns the first choice's message content
      from ``run_response_type_prompt``.
    - any other scene: returns the first choice's message content from
      ``run_chat_prompt``.
    The last two results are presumably OpenAI-style chat completions
    (TODO confirm against pqai_agent_server.utils).

    A missing, non-JSON or non-object body, or a missing ``scene``, is a
    client error and yields code 400. Failures while running the prompt are
    logged with their traceback and reported as code 500.
    """
    # silent=True returns None on a wrong content type or unparsable JSON,
    # instead of raising, so we can answer with a 400 rather than a 500.
    req_data = request.get_json(silent=True)
    logger.debug(req_data)
    if not isinstance(req_data, dict) or "scene" not in req_data:
        return wrap_response(400, msg="JSON object body with a 'scene' field is required")
    scene = req_data["scene"]
    try:
        if scene == "profile_extractor":
            return wrap_response(200, data=run_extractor_prompt(req_data))
        if scene == "response_type_detector":
            response = run_response_type_prompt(req_data)
        else:
            response = run_chat_prompt(req_data)
        return wrap_response(200, data=response.choices[0].message.content)
    except Exception as e:
        # Top-level request boundary: keep the traceback for diagnosis.
        logger.exception(e)
        return wrap_response(500, msg="Error: {}".format(e))
@app.route("/api/healthCheck", methods=["GET"])
def health_check():
    """Liveness probe that unconditionally reports success."""
    ok_message = "OK"
    return wrap_response(200, msg=ok_message)
@app.route("/api/getStaffSessionSummary", methods=["GET"])
def get_staff_session_summary():
    """Return a paginated session summary for a staff member.

    Query args:
        staff_id: optional; passed through as-is (None when absent).
        status: integer, default 1. Its meaning is defined by
            ``get_staff_sessions_summary_v1`` (TODO confirm).
        page_id: integer page number, default 1.
        page_size: integer page size, default 10.

    Malformed integers yield code 400 (previously 404, which wrongly
    signalled a missing resource). A falsy summary yields code 404.
    """
    staff_id = request.args.get("staff_id")
    status = request.args.get("status", 1)
    page_id = request.args.get("page_id", 1)
    page_size = request.args.get("page_size", 10)
    try:
        # Query-string values are strings; the defaults are already ints.
        page_id = int(page_id)
        page_size = int(page_size)
        status = int(status)
    except ValueError as e:
        return wrap_response(400, msg="Invalid parameter: {}".format(e))
    staff_session_summary = app.user_manager.get_staff_sessions_summary_v1(
        staff_id, page_id, page_size, status
    )
    if not staff_session_summary:
        return wrap_response(404, msg="staff not found")
    return wrap_response(200, data=staff_session_summary)
@app.route("/api/getStaffSessionList", methods=["GET"])
def get_staff_session_list():
    """Return a paginated session list for the required ``staff_id``.

    ``page_id`` (default 1) and ``page_size`` (default 10) are converted to
    ``int``. Previously they reached the user manager as ``str`` when supplied
    and as ``int`` when defaulted. Missing ``staff_id`` or malformed paging
    values yield code 400; a falsy result yields code 404.
    """
    staff_id = request.args.get("staff_id")
    if not staff_id:
        return wrap_response(400, msg="staff_id is required")
    try:
        page_size = int(request.args.get("page_size", 10))
        page_id = int(request.args.get("page_id", 1))
    except ValueError as e:
        return wrap_response(400, msg="Invalid parameter: {}".format(e))
    staff_session_list = app.user_manager.get_staff_session_list_v1(
        staff_id, page_id, page_size
    )
    if not staff_session_list:
        return wrap_response(404, msg="staff not found")
    return wrap_response(200, data=staff_session_list)
@app.route("/api/getStaffList", methods=["GET"])
def get_staff_list():
    """Return a paginated staff list.

    ``page_id`` (default 1) and ``page_size`` (default 10) are converted to
    ``int`` so the user manager always receives the same type. Previously
    supplied values were passed as ``str``. Malformed values yield code 400;
    an empty result yields code 404.
    """
    try:
        page_size = int(request.args.get("page_size", 10))
        page_id = int(request.args.get("page_id", 1))
    except ValueError as e:
        return wrap_response(400, msg="Invalid parameter: {}".format(e))
    staff_list = app.user_manager.get_staff_list(page_id, page_size)
    if not staff_list:
        return wrap_response(404, msg="staff not found")
    return wrap_response(200, data=staff_list)
@app.route("/api/getConversationList", methods=["GET"])
def get_conversation_list():
    """Return the private-chat conversation list between a staff member and a customer.

    ``staff_id`` and ``customer_id`` are required. Their absence is a client
    error and yields code 400, matching the 400 that the ``BadRequest``
    handler produces for other missing required arguments.
    ``page`` is optional and passed through unchanged (a string or None).
    NOTE(review): its expected format is defined by
    ``get_conversation_list_v1``; confirm before converting it.
    """
    staff_id = request.args.get("staff_id")
    customer_id = request.args.get("customer_id")
    if not staff_id or not customer_id:
        return wrap_response(400, msg="staff_id and customer_id are required")
    page = request.args.get("page")
    response = app.user_manager.get_conversation_list_v1(staff_id, customer_id, page)
    return wrap_response(200, data=response)
@app.route("/api/sendMessage", methods=["POST"])
def send_message():
    """Placeholder endpoint: message sending is not implemented yet."""
    not_implemented_msg = "暂不实现功能"
    return wrap_response(200, msg=not_implemented_msg)
@app.errorhandler(werkzeug.exceptions.BadRequest)
def handle_bad_request(e):
    """Convert werkzeug ``BadRequest`` errors into the code-400 envelope.

    This covers missing required ``request.args[...]`` keys, which raise
    ``BadRequestKeyError`` (a ``BadRequest`` subclass).
    """
    message = f"Bad Request: {e.description}"
    logger.error(e)
    return wrap_response(400, msg=message)
if __name__ == "__main__":
    # Script entry point: parse CLI options, configure logging, build the
    # storage-backed services, attach them to `app`, then start Flask's
    # built-in server.
    parser = ArgumentParser()
    parser.add_argument("--prod", action="store_true")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8083)
    # Upper-case before validating: logging.getLevelName("info") returns the
    # string "Level info" rather than a numeric level. Only names that
    # getLevelName maps to a level are accepted.
    parser.add_argument(
        "--log-level",
        default="INFO",
        type=str.upper,
        choices=["CRITICAL", "FATAL", "ERROR", "WARN", "WARNING", "INFO", "DEBUG", "NOTSET"],
    )
    args = parser.parse_args()
    config = configs.get()

    logging_level = logging.getLevelName(args.log_level)
    logging_service.setup_root_logger(
        level=logging_level, logfile_name="agent_api_server.log"
    )

    user_db_config = config["storage"]["user"]
    staff_db_config = config["storage"]["staff"]
    app.user_manager = MySQLUserManager(
        user_db_config["mysql"], user_db_config["table"], staff_db_config["table"]
    )

    wecom_db_config = config["storage"]["user_relation"]
    app.user_relation_manager = MySQLUserRelationManager(
        user_db_config["mysql"],
        wecom_db_config["mysql"],
        staff_db_config["table"],
        user_db_config["table"],
        wecom_db_config["table"]["staff"],
        wecom_db_config["table"]["relation"],
        wecom_db_config["table"]["user"],
    )

    app.history_dialogue_service = HistoryDialogueService(
        config["storage"]["history_dialogue"]["api_base_url"]
    )

    # NOTE: Flask debug mode (interactive debugger, reloader) is ON unless
    # --prod is given; never expose a non-prod instance beyond localhost.
    app.run(debug=not args.prod, host=args.host, port=args.port)
|