api_server.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import logging
  5. import werkzeug.exceptions
  6. from flask import Flask, request, jsonify
  7. from argparse import ArgumentParser
  8. from pqai_agent import configs
  9. from pqai_agent import logging_service, chat_service, prompt_templates
  10. from pqai_agent.history_dialogue_service import HistoryDialogueService
  11. from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
  12. from pqai_agent_server.const import AgentApiConst
  13. from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
  14. from pqai_agent_server.utils import (
  15. run_extractor_prompt,
  16. run_chat_prompt,
  17. run_response_type_prompt,
  18. )
  19. app = Flask("agent_api_server")
  20. logger = logging_service.logger
  21. const = AgentApiConst()
  22. @app.route("/api/listStaffs", methods=["GET"])
  23. def list_staffs():
  24. staff_data = app.user_relation_manager.list_staffs()
  25. return wrap_response(200, data=staff_data)
  26. @app.route("/api/getStaffProfile", methods=["GET"])
  27. def get_staff_profile():
  28. staff_id = request.args["staff_id"]
  29. profile = app.user_manager.get_staff_profile(staff_id)
  30. if not profile:
  31. return wrap_response(404, msg="staff not found")
  32. else:
  33. return wrap_response(200, data=profile)
  34. @app.route("/api/getUserProfile", methods=["GET"])
  35. def get_user_profile():
  36. user_id = request.args["user_id"]
  37. profile = app.user_manager.get_user_profile(user_id)
  38. if not profile:
  39. resp = {"code": 404, "msg": "user not found"}
  40. else:
  41. resp = {"code": 200, "msg": "success", "data": profile}
  42. return jsonify(resp)
  43. @app.route("/api/listUsers", methods=["GET"])
  44. def list_users():
  45. user_name = request.args.get("user_name", None)
  46. user_union_id = request.args.get("user_union_id", None)
  47. if not user_name and not user_union_id:
  48. resp = {"code": 400, "msg": "user_name or user_union_id is required"}
  49. return jsonify(resp)
  50. data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
  51. return jsonify({"code": 200, "data": data})
  52. @app.route("/api/getDialogueHistory", methods=["GET"])
  53. def get_dialogue_history():
  54. staff_id = request.args["staff_id"]
  55. user_id = request.args["user_id"]
  56. recent_minutes = int(request.args.get("recent_minutes", 1440))
  57. dialogue_history = app.history_dialogue_service.get_dialogue_history(
  58. staff_id, user_id, recent_minutes
  59. )
  60. return jsonify({"code": 200, "data": dialogue_history})
  61. @app.route("/api/listModels", methods=["GET"])
  62. def list_models():
  63. models = [
  64. {
  65. "model_type": "openai_compatible",
  66. "model_name": chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
  67. "display_name": "DeepSeek V3 on 火山",
  68. },
  69. {
  70. "model_type": "openai_compatible",
  71. "model_name": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  72. "display_name": "豆包Pro 32K",
  73. },
  74. {
  75. "model_type": "openai_compatible",
  76. "model_name": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  77. "display_name": "豆包Pro 1.5",
  78. },
  79. {
  80. "model_type": "openai_compatible",
  81. "model_name": chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
  82. "display_name": "DeepSeek V3联网 on 火山",
  83. },
  84. {
  85. "model_type": "openai_compatible",
  86. "model_name": chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
  87. "display_name": "豆包1.5视觉理解Pro",
  88. },
  89. ]
  90. return wrap_response(200, data=models)
  91. @app.route("/api/listScenes", methods=["GET"])
  92. def list_scenes():
  93. scenes = [
  94. {"scene": "greeting", "display_name": "问候"},
  95. {"scene": "chitchat", "display_name": "闲聊"},
  96. {"scene": "profile_extractor", "display_name": "画像提取"},
  97. {"scene": "response_type_detector", "display_name": "回复模态判断"},
  98. {"scene": "custom_debugging", "display_name": "自定义调试场景"},
  99. ]
  100. return wrap_response(200, data=scenes)
  101. @app.route("/api/getBasePrompt", methods=["GET"])
  102. def get_base_prompt():
  103. scene = request.args["scene"]
  104. prompt_map = {
  105. "greeting": prompt_templates.GENERAL_GREETING_PROMPT,
  106. "chitchat": prompt_templates.CHITCHAT_PROMPT_COZE,
  107. "profile_extractor": prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
  108. "response_type_detector": prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
  109. "custom_debugging": "",
  110. }
  111. model_map = {
  112. "greeting": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  113. "chitchat": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  114. "profile_extractor": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  115. "response_type_detector": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  116. "custom_debugging": chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
  117. }
  118. if scene not in prompt_map:
  119. return wrap_response(404, msg="scene not found")
  120. data = {"model_name": model_map[scene], "content": prompt_map[scene]}
  121. return wrap_response(200, data=data)
  122. @app.route("/api/runPrompt", methods=["POST"])
  123. def run_prompt():
  124. try:
  125. req_data = request.json
  126. logger.debug(req_data)
  127. scene = req_data["scene"]
  128. if scene == "profile_extractor":
  129. response = run_extractor_prompt(req_data)
  130. return wrap_response(200, data=response)
  131. elif scene == "response_type_detector":
  132. response = run_response_type_prompt(req_data)
  133. return wrap_response(200, data=response.choices[0].message.content)
  134. else:
  135. response = run_chat_prompt(req_data)
  136. return wrap_response(200, data=response.choices[0].message.content)
  137. except Exception as e:
  138. logger.error(e)
  139. return wrap_response(500, msg="Error: {}".format(e))
  140. @app.route("/api/healthCheck", methods=["GET"])
  141. def health_check():
  142. return wrap_response(200, msg="OK")
  143. @app.route("/api/getStaffSessionSummary", methods=["GET"])
  144. def get_staff_session_summary():
  145. staff_id = request.args.get("staff_id")
  146. status = request.args.get("status", const.DEFAULT_STAFF_STATUS)
  147. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  148. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  149. # check params
  150. try:
  151. page_id = int(page_id)
  152. page_size = int(page_size)
  153. status = int(status)
  154. except Exception as e:
  155. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  156. staff_session_summary = app.user_manager.get_staff_sessions_summary_v1(
  157. staff_id, page_id, page_size, status
  158. )
  159. if not staff_session_summary:
  160. return wrap_response(404, msg="staff not found")
  161. else:
  162. return wrap_response(200, data=staff_session_summary)
  163. @app.route("/api/getStaffSessionList", methods=["GET"])
  164. def get_staff_session_list():
  165. staff_id = request.args.get("staff_id")
  166. if not staff_id:
  167. return wrap_response(404, msg="staff_id is required")
  168. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  169. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  170. staff_session_list = app.user_manager.get_staff_session_list_v1(
  171. staff_id, page_id, page_size
  172. )
  173. if not staff_session_list:
  174. return wrap_response(404, msg="staff not found")
  175. return wrap_response(200, data=staff_session_list)
  176. @app.route("/api/getStaffList", methods=["GET"])
  177. def get_staff_list():
  178. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  179. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  180. staff_list = app.user_manager.get_staff_list(page_id, page_size)
  181. if not staff_list:
  182. return wrap_response(404, msg="staff not found")
  183. return wrap_response(200, data=staff_list)
  184. @app.route("/api/getConversationList", methods=["GET"])
  185. def get_conversation_list():
  186. """
  187. 获取staff && customer的 私聊对话列表
  188. :return:
  189. """
  190. staff_id = request.args.get("staff_id")
  191. customer_id = request.args.get("customer_id")
  192. if not staff_id or not customer_id:
  193. return wrap_response(404, msg="staff_id and customer_id are required")
  194. page = request.args.get("page")
  195. response = app.user_manager.get_conversation_list_v1(staff_id, customer_id, page, const.DEFAULT_CONVERSATION_SIZE)
  196. return wrap_response(200, data=response)
  197. @app.route("/api/quitHumanInterventionStatus", methods=["GET"])
  198. def quit_human_interventions_status():
  199. """
  200. 退出人工介入状态
  201. :return:
  202. """
  203. staff_id = request.args.get("staff_id")
  204. customer_id = request.args.get("customer_id")
  205. # 测试环境: staff_id 强制等于1688854492669990
  206. staff_id = 1688854492669990
  207. if not customer_id or not staff_id:
  208. return wrap_response(404, msg="user_id and staff_id are required")
  209. response = quit_human_intervention_status(customer_id, staff_id)
  210. return wrap_response(200, data=response)
  211. @app.route("/api/sendMessage", methods=["POST"])
  212. def send_message():
  213. return wrap_response(200, msg="暂不实现功能")
  214. @app.errorhandler(werkzeug.exceptions.BadRequest)
  215. def handle_bad_request(e):
  216. logger.error(e)
  217. return wrap_response(400, msg="Bad Request: {}".format(e.description))
  218. if __name__ == "__main__":
  219. parser = ArgumentParser()
  220. parser.add_argument("--prod", action="store_true")
  221. parser.add_argument("--host", default="127.0.0.1")
  222. parser.add_argument("--port", type=int, default=8083)
  223. parser.add_argument("--log-level", default="INFO")
  224. args = parser.parse_args()
  225. config = configs.get()
  226. logging_level = logging.getLevelName(args.log_level)
  227. logging_service.setup_root_logger(
  228. level=logging_level, logfile_name="agent_api_server.log"
  229. )
  230. user_db_config = config["storage"]["user"]
  231. staff_db_config = config["storage"]["staff"]
  232. user_manager = MySQLUserManager(
  233. user_db_config["mysql"], user_db_config["table"], staff_db_config["table"]
  234. )
  235. app.user_manager = user_manager
  236. wecom_db_config = config["storage"]["user_relation"]
  237. user_relation_manager = MySQLUserRelationManager(
  238. user_db_config["mysql"],
  239. wecom_db_config["mysql"],
  240. config["storage"]["staff"]["table"],
  241. user_db_config["table"],
  242. wecom_db_config["table"]["staff"],
  243. wecom_db_config["table"]["relation"],
  244. wecom_db_config["table"]["user"],
  245. )
  246. app.user_relation_manager = user_relation_manager
  247. app.history_dialogue_service = HistoryDialogueService(
  248. config["storage"]["history_dialogue"]["api_base_url"]
  249. )
  250. app.run(debug=not args.prod, host=args.host, port=args.port)