api_server.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import time
  5. import logging
  6. import werkzeug.exceptions
  7. from flask import Flask, request, jsonify
  8. from argparse import ArgumentParser
  9. from pqai_agent import configs
  10. from pqai_agent import logging_service, chat_service, prompt_templates
  11. from pqai_agent.history_dialogue_service import HistoryDialogueService
  12. from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
  13. from pqai_agent_server.const import AgentApiConst
  14. from pqai_agent_server.models import MySQLSessionManager
  15. from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
  16. from pqai_agent_server.utils import (
  17. run_extractor_prompt,
  18. run_chat_prompt,
  19. run_response_type_prompt,
  20. )
# Flask application serving the agent API. The service objects the handlers
# use (user_manager, session_manager, user_relation_manager,
# history_dialogue_service) are attached as attributes in the __main__ block.
app = Flask('agent_api_server')
# Shared logger provided by pqai_agent.logging_service.
logger = logging_service.logger
# Holder for endpoint defaults (DEFAULT_PAGE_ID, DEFAULT_PAGE_SIZE,
# DEFAULT_STAFF_STATUS, DEFAULT_CONVERSATION_SIZE).
const = AgentApiConst()
  24. @app.route('/api/listStaffs', methods=['GET'])
  25. def list_staffs():
  26. staff_data = app.user_relation_manager.list_staffs()
  27. return wrap_response(200, data=staff_data)
  28. @app.route('/api/getStaffProfile', methods=['GET'])
  29. def get_staff_profile():
  30. staff_id = request.args['staff_id']
  31. profile = app.user_manager.get_staff_profile(staff_id)
  32. if not profile:
  33. return wrap_response(404, msg='staff not found')
  34. else:
  35. return wrap_response(200, data=profile)
  36. @app.route('/api/getUserProfile', methods=['GET'])
  37. def get_user_profile():
  38. user_id = request.args['user_id']
  39. profile = app.user_manager.get_user_profile(user_id)
  40. if not profile:
  41. resp = {
  42. 'code': 404,
  43. 'msg': 'user not found'
  44. }
  45. else:
  46. resp = {
  47. 'code': 200,
  48. 'msg': 'success',
  49. 'data': profile
  50. }
  51. return jsonify(resp)
  52. @app.route('/api/listUsers', methods=['GET'])
  53. def list_users():
  54. user_name = request.args.get('user_name', None)
  55. user_union_id = request.args.get('user_union_id', None)
  56. if not user_name and not user_union_id:
  57. resp = {
  58. 'code': 400,
  59. 'msg': 'user_name or user_union_id is required'
  60. }
  61. return jsonify(resp)
  62. data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
  63. return jsonify({'code': 200, 'data': data})
  64. @app.route('/api/getDialogueHistory', methods=['GET'])
  65. def get_dialogue_history():
  66. staff_id = request.args['staff_id']
  67. user_id = request.args['user_id']
  68. recent_minutes = int(request.args.get('recent_minutes', 1440))
  69. dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
  70. return jsonify({'code': 200, 'data': dialogue_history})
  71. @app.route('/api/listModels', methods=['GET'])
  72. def list_models():
  73. models = [
  74. {
  75. 'model_type': 'openai_compatible',
  76. 'model_name': chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
  77. 'display_name': 'DeepSeek V3 on 火山'
  78. },
  79. {
  80. 'model_type': 'openai_compatible',
  81. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  82. 'display_name': '豆包Pro 32K'
  83. },
  84. {
  85. 'model_type': 'openai_compatible',
  86. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  87. 'display_name': '豆包Pro 1.5'
  88. },
  89. {
  90. 'model_type': 'openai_compatible',
  91. 'model_name': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
  92. 'display_name': 'DeepSeek V3联网 on 火山'
  93. },
  94. {
  95. 'model_type': 'openai_compatible',
  96. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
  97. 'display_name': '豆包1.5视觉理解Pro'
  98. },
  99. ]
  100. return wrap_response(200, data=models)
  101. @app.route('/api/listScenes', methods=['GET'])
  102. def list_scenes():
  103. scenes = [
  104. {'scene': 'greeting', 'display_name': '问候'},
  105. {'scene': 'chitchat', 'display_name': '闲聊'},
  106. {'scene': 'profile_extractor', 'display_name': '画像提取'},
  107. {'scene': 'response_type_detector', 'display_name': '回复模态判断'},
  108. {'scene': 'custom_debugging', 'display_name': '自定义调试场景'}
  109. ]
  110. return wrap_response(200, data=scenes)
  111. @app.route('/api/getBasePrompt', methods=['GET'])
  112. def get_base_prompt():
  113. scene = request.args['scene']
  114. prompt_map = {
  115. 'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
  116. 'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
  117. 'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
  118. 'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
  119. 'custom_debugging': '',
  120. }
  121. model_map = {
  122. 'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  123. 'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  124. 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  125. 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  126. 'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
  127. }
  128. if scene not in prompt_map:
  129. return wrap_response(404, msg='scene not found')
  130. data = {
  131. 'model_name': model_map[scene],
  132. 'content': prompt_map[scene]
  133. }
  134. return wrap_response(200, data=data)
  135. @app.route('/api/runPrompt', methods=['POST'])
  136. def run_prompt():
  137. try:
  138. req_data = request.json
  139. logger.debug(req_data)
  140. scene = req_data['scene']
  141. if scene == 'profile_extractor':
  142. response = run_extractor_prompt(req_data)
  143. return wrap_response(200, data=response)
  144. elif scene == 'response_type_detector':
  145. response = run_response_type_prompt(req_data)
  146. return wrap_response(200, data=response.choices[0].message.content)
  147. else:
  148. response = run_chat_prompt(req_data)
  149. return wrap_response(200, data=response.choices[0].message.content)
  150. except Exception as e:
  151. logger.error(e)
  152. return wrap_response(500, msg='Error: {}'.format(e))
  153. @app.route("/api/healthCheck", methods=["GET"])
  154. def health_check():
  155. return wrap_response(200, msg="OK")
  156. @app.route("/api/getStaffSessionSummary", methods=["GET"])
  157. def get_staff_session_summary():
  158. staff_id = request.args.get("staff_id")
  159. status = request.args.get("status", const.DEFAULT_STAFF_STATUS)
  160. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  161. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  162. # check params
  163. try:
  164. page_id = int(page_id)
  165. page_size = int(page_size)
  166. status = int(status)
  167. except Exception as e:
  168. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  169. staff_session_summary = app.session_manager.get_staff_sessions_summary(
  170. staff_id, page_id, page_size, status
  171. )
  172. if not staff_session_summary:
  173. return wrap_response(404, msg="staff not found")
  174. else:
  175. return wrap_response(200, data=staff_session_summary)
  176. @app.route("/api/getStaffSessionList", methods=["GET"])
  177. def get_staff_session_list():
  178. staff_id = request.args.get("staff_id")
  179. if not staff_id:
  180. return wrap_response(404, msg="staff_id is required")
  181. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  182. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  183. staff_session_list = app.session_manager.get_staff_session_list(staff_id, page_id, page_size)
  184. if not staff_session_list:
  185. return wrap_response(404, msg="staff not found")
  186. return wrap_response(200, data=staff_session_list)
  187. @app.route("/api/getStaffList", methods=["GET"])
  188. def get_staff_list():
  189. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  190. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  191. staff_list = app.user_manager.get_staff_list(page_id, page_size)
  192. if not staff_list:
  193. return wrap_response(404, msg="staff not found")
  194. return wrap_response(200, data=staff_list)
  195. @app.route("/api/getConversationList", methods=["GET"])
  196. def get_conversation_list():
  197. """
  198. 获取staff && customer的 私聊对话列表
  199. :return:
  200. """
  201. staff_id = request.args.get("staff_id")
  202. customer_id = request.args.get("customer_id")
  203. if not staff_id or not customer_id:
  204. return wrap_response(404, msg="staff_id and customer_id are required")
  205. page = request.args.get("page")
  206. response = app.session_manager.get_conversation_list(staff_id, customer_id, page, const.DEFAULT_CONVERSATION_SIZE)
  207. return wrap_response(200, data=response)
  208. @app.route("/api/sendMessage", methods=["POST"])
  209. def send_message():
  210. return wrap_response(200, msg="暂不实现功能")
  211. @app.route("/api/quitHumanInterventionStatus", methods=["GET"])
  212. def quit_human_interventions_status():
  213. """
  214. 退出人工介入状态
  215. :return:
  216. """
  217. staff_id = request.args.get("staff_id")
  218. customer_id = request.args.get("customer_id")
  219. # 测试环境: staff_id 强制等于1688854492669990
  220. staff_id = 1688854492669990
  221. if not customer_id or not staff_id:
  222. return wrap_response(404, msg="user_id and staff_id are required")
  223. response = quit_human_intervention_status(customer_id, staff_id)
  224. return wrap_response(200, data=response)
  225. @app.errorhandler(werkzeug.exceptions.BadRequest)
  226. def handle_bad_request(e):
  227. logger.error(e)
  228. return wrap_response(400, msg='Bad Request: {}'.format(e.description))
  229. if __name__ == '__main__':
  230. parser = ArgumentParser()
  231. parser.add_argument('--prod', action='store_true')
  232. parser.add_argument('--host', default='127.0.0.1')
  233. parser.add_argument('--port', type=int, default=8083)
  234. parser.add_argument('--log-level', default='INFO')
  235. args = parser.parse_args()
  236. config = configs.get()
  237. logging_level = logging.getLevelName(args.log_level)
  238. logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
  239. # set db config
  240. user_db_config = config['storage']['user']
  241. staff_db_config = config['storage']['staff']
  242. agent_state_db_config = config['storage']['agent_state']
  243. chat_history_db_config = config['storage']['chat_history']
  244. # init user manager
  245. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
  246. app.user_manager = user_manager
  247. # init session manager
  248. session_manager = MySQLSessionManager(
  249. db_config=user_db_config['mysql'],
  250. staff_table=staff_db_config['table'],
  251. user_table=user_db_config['table'],
  252. agent_state_table=agent_state_db_config['table'],
  253. chat_history_table=chat_history_db_config['table']
  254. )
  255. app.session_manager = session_manager
  256. wecom_db_config = config['storage']['user_relation']
  257. user_relation_manager = MySQLUserRelationManager(
  258. user_db_config['mysql'], wecom_db_config['mysql'],
  259. config['storage']['staff']['table'],
  260. user_db_config['table'],
  261. wecom_db_config['table']['staff'],
  262. wecom_db_config['table']['relation'],
  263. wecom_db_config['table']['user']
  264. )
  265. app.user_relation_manager = user_relation_manager
  266. app.history_dialogue_service = HistoryDialogueService(
  267. config['storage']['history_dialogue']['api_base_url']
  268. )
  269. app.run(debug=not args.prod, host=args.host, port=args.port)