api_server.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import json
  5. import time
  6. import logging
  7. import werkzeug.exceptions
  8. from flask import Flask, request, jsonify
  9. from argparse import ArgumentParser
  10. from pqai_agent import configs
  11. from pqai_agent import logging_service, chat_service, prompt_templates
  12. from pqai_agent.agents.message_reply_agent import MessageReplyAgent
  13. from pqai_agent.configs import apollo_config
  14. from pqai_agent.history_dialogue_service import HistoryDialogueService
  15. from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
  16. from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
  17. from pqai_agent_server.const import AgentApiConst
  18. from pqai_agent_server.models import MySQLSessionManager, MySQLStaffManager
  19. from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
  20. from pqai_agent_server.utils import (
  21. run_extractor_prompt,
  22. run_chat_prompt,
  23. run_response_type_prompt,
  24. )
  25. app = Flask('agent_api_server')
  26. logger = logging_service.logger
  27. const = AgentApiConst()
  28. @app.route('/api/listStaffs', methods=['GET'])
  29. def list_staffs():
  30. staff_data = app.user_relation_manager.list_staffs()
  31. return wrap_response(200, data=staff_data)
  32. @app.route('/api/getStaffProfile', methods=['GET'])
  33. def get_staff_profile():
  34. staff_id = request.args['staff_id']
  35. if not staff_id:
  36. return wrap_response(400, msg='staff_id is required')
  37. agent_profile = app.staff_manager.get_staff_profile(staff_id)
  38. if not agent_profile:
  39. return wrap_response(404, msg='staff not found')
  40. else:
  41. field_map_list = apollo_config.get_json_value("field_map_list", [])
  42. field_map = {
  43. item["field_name"]: item["display_name"] for item in field_map_list
  44. }
  45. profile_info = [
  46. {
  47. "field_name": key,
  48. "display_name": field_map[key],
  49. "field_value": value,
  50. }
  51. for key, value in agent_profile.items()
  52. if agent_profile.get(key)
  53. ]
  54. return wrap_response(200, data=profile_info)
  55. @app.route('/api/saveStaffProfile', methods=['POST'])
  56. def save_staff_profile():
  57. staff_id = request.json.get('staff_id')
  58. staff_profile = request.json.get('staff_profile')
  59. if not staff_id:
  60. return wrap_response(400, msg='staff id is required')
  61. if not staff_profile:
  62. return wrap_response(400, msg='profile is required')
  63. else:
  64. try:
  65. profile_info_list = json.loads(staff_profile)
  66. profile_dict = {item['field_name']: item['field_value'] for item in profile_info_list}
  67. affected_rows = app.staff_manager.save_staff_profile(staff_id, profile_dict)
  68. if not affected_rows:
  69. return wrap_response(500, msg='save staff profile failed')
  70. else:
  71. return wrap_response(200, msg='save staff profile success')
  72. except json.decoder.JSONDecodeError:
  73. return wrap_response(400, msg='profile is not a valid json')
  74. @app.route('/api/getProfileFields', methods=['GET'])
  75. def get_profile_fields():
  76. return wrap_response(200, data=apollo_config.get_json_value("field_map_list", []))
  77. @app.route('/api/getUserProfile', methods=['GET'])
  78. def get_user_profile():
  79. user_id = request.args['user_id']
  80. profile = app.user_manager.get_user_profile(user_id)
  81. if not profile:
  82. resp = {
  83. 'code': 404,
  84. 'msg': 'user not found'
  85. }
  86. else:
  87. resp = {
  88. 'code': 200,
  89. 'msg': 'success',
  90. 'data': profile
  91. }
  92. return jsonify(resp)
  93. @app.route('/api/listUsers', methods=['GET'])
  94. def list_users():
  95. user_name = request.args.get('user_name', None)
  96. user_union_id = request.args.get('user_union_id', None)
  97. if not user_name and not user_union_id:
  98. resp = {
  99. 'code': 400,
  100. 'msg': 'user_name or user_union_id is required'
  101. }
  102. return jsonify(resp)
  103. data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
  104. return jsonify({'code': 200, 'data': data})
  105. @app.route('/api/getDialogueHistory', methods=['GET'])
  106. def get_dialogue_history():
  107. staff_id = request.args['staff_id']
  108. user_id = request.args['user_id']
  109. recent_minutes = int(request.args.get('recent_minutes', 1440))
  110. dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
  111. return jsonify({'code': 200, 'data': dialogue_history})
  112. @app.route('/api/listModels', methods=['GET'])
  113. def list_models():
  114. models = [
  115. {
  116. 'model_type': 'openai_compatible',
  117. 'model_name': chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
  118. 'display_name': 'DeepSeek V3 on 火山'
  119. },
  120. {
  121. 'model_type': 'openai_compatible',
  122. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  123. 'display_name': '豆包Pro 32K'
  124. },
  125. {
  126. 'model_type': 'openai_compatible',
  127. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  128. 'display_name': '豆包Pro 1.5'
  129. },
  130. {
  131. 'model_type': 'openai_compatible',
  132. 'model_name': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
  133. 'display_name': 'DeepSeek V3联网 on 火山'
  134. },
  135. {
  136. 'model_type': 'openai_compatible',
  137. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
  138. 'display_name': '豆包1.5视觉理解Pro'
  139. },
  140. ]
  141. return wrap_response(200, data=models)
  142. @app.route('/api/listScenes', methods=['GET'])
  143. def list_scenes():
  144. scenes = [
  145. {'scene': 'greeting', 'display_name': '问候'},
  146. {'scene': 'chitchat', 'display_name': '闲聊'},
  147. {'scene': 'profile_extractor', 'display_name': '画像提取'},
  148. {'scene': 'response_type_detector', 'display_name': '回复模态判断'},
  149. {'scene': 'custom_debugging', 'display_name': '自定义调试场景'}
  150. ]
  151. return wrap_response(200, data=scenes)
  152. @app.route('/api/getBasePrompt', methods=['GET'])
  153. def get_base_prompt():
  154. scene = request.args['scene']
  155. prompt_map = {
  156. 'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
  157. 'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
  158. 'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
  159. 'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
  160. 'custom_debugging': '',
  161. }
  162. model_map = {
  163. 'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  164. 'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  165. 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  166. 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  167. 'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
  168. }
  169. if scene not in prompt_map:
  170. return wrap_response(404, msg='scene not found')
  171. data = {
  172. 'model_name': model_map[scene],
  173. 'content': prompt_map[scene]
  174. }
  175. return wrap_response(200, data=data)
  176. @app.route('/api/runPrompt', methods=['POST'])
  177. def run_prompt():
  178. try:
  179. req_data = request.json
  180. logger.debug(req_data)
  181. scene = req_data['scene']
  182. if scene == 'profile_extractor':
  183. response = run_extractor_prompt(req_data)
  184. return wrap_response(200, data=response)
  185. elif scene == 'response_type_detector':
  186. response = run_response_type_prompt(req_data)
  187. return wrap_response(200, data=response.choices[0].message.content)
  188. else:
  189. response = run_chat_prompt(req_data)
  190. return wrap_response(200, data=response.choices[0].message.content)
  191. except Exception as e:
  192. logger.error(e)
  193. return wrap_response(500, msg='Error: {}'.format(e))
  194. @app.route('/api/formatForPrompt', methods=['POST'])
  195. def format_data_for_prompt():
  196. try:
  197. req_data = request.json
  198. content = req_data['content']
  199. format_type = req_data['format_type']
  200. if format_type == 'staff_profile':
  201. if not isinstance(content, dict):
  202. return wrap_response(400, msg='staff_profile should be a dict')
  203. response = format_agent_profile(content)
  204. elif format_type == 'user_profile':
  205. if not isinstance(content, dict):
  206. return wrap_response(400, msg='user_profile should be a dict')
  207. response = format_user_profile(content)
  208. elif format_type == 'dialogue':
  209. if not isinstance(content, list):
  210. return wrap_response(400, msg='dialogue should be a list')
  211. from pqai_agent_server.utils.prompt_util import format_dialogue_history
  212. response = format_dialogue_history(content)
  213. else:
  214. return wrap_response(400, msg='Invalid format_type')
  215. return wrap_response(200, data=response)
  216. except Exception as e:
  217. logger.error(e)
  218. return wrap_response(500, msg='Error: {}'.format(e))
  219. @app.route("/api/healthCheck", methods=["GET"])
  220. def health_check():
  221. return wrap_response(200, msg="OK")
  222. @app.route("/api/getStaffSessionSummary", methods=["GET"])
  223. def get_staff_session_summary():
  224. staff_id = request.args.get("staff_id")
  225. status = request.args.get("status", const.DEFAULT_STAFF_STATUS)
  226. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  227. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  228. # check params
  229. try:
  230. page_id = int(page_id)
  231. page_size = int(page_size)
  232. status = int(status)
  233. except Exception as e:
  234. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  235. staff_session_summary = app.session_manager.get_staff_sessions_summary(
  236. staff_id, page_id, page_size, status
  237. )
  238. if not staff_session_summary:
  239. return wrap_response(404, msg="staff not found")
  240. else:
  241. return wrap_response(200, data=staff_session_summary)
  242. @app.route("/api/getStaffSessionList", methods=["GET"])
  243. def get_staff_session_list():
  244. staff_id = request.args.get("staff_id")
  245. if not staff_id:
  246. return wrap_response(404, msg="staff_id is required")
  247. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  248. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  249. # check params
  250. try:
  251. page_id = int(page_id)
  252. page_size = int(page_size)
  253. except Exception as e:
  254. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  255. staff_session_list = app.session_manager.get_staff_session_list(staff_id, page_id, page_size)
  256. if not staff_session_list:
  257. return wrap_response(404, msg="staff not found")
  258. return wrap_response(200, data=staff_session_list)
  259. @app.route("/api/getStaffList", methods=["GET"])
  260. def get_staff_list():
  261. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  262. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  263. # check params
  264. try:
  265. page_id = int(page_id)
  266. page_size = int(page_size)
  267. except Exception as e:
  268. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  269. staff_list = app.user_manager.get_staff_list(page_id, page_size)
  270. if not staff_list:
  271. return wrap_response(404, msg="staff not found")
  272. return wrap_response(200, data=staff_list)
  273. @app.route("/api/getConversationList", methods=["GET"])
  274. def get_conversation_list():
  275. """
  276. 获取staff && user 私聊对话列表
  277. :return:
  278. """
  279. staff_id = request.args.get("staff_id")
  280. user_id = request.args.get("user_id")
  281. if not staff_id or not user_id:
  282. return wrap_response(404, msg="staff_id and user_id are required")
  283. page = request.args.get("page_id")
  284. response = app.session_manager.get_conversation_list(staff_id, user_id, page, const.DEFAULT_CONVERSATION_SIZE)
  285. return wrap_response(200, data=response)
  286. @app.route("/api/sendMessage", methods=["POST"])
  287. def send_message():
  288. return wrap_response(200, msg="暂不实现功能")
  289. @app.route("/api/quitHumanInterventionStatus", methods=["POST"])
  290. def quit_human_interventions_status():
  291. """
  292. 退出人工介入状态
  293. :return:
  294. """
  295. req_data = request.json
  296. staff_id = req_data["staff_id"]
  297. user_id = req_data["user_id"]
  298. if not user_id or not staff_id:
  299. return wrap_response(404, msg="user_id and staff_id are required")
  300. staff_id = '1688854492669990'
  301. response = quit_human_intervention_status(user_id, staff_id)
  302. return wrap_response(200, data=response)
  303. @app.errorhandler(werkzeug.exceptions.BadRequest)
  304. def handle_bad_request(e):
  305. logger.error(e)
  306. return wrap_response(400, msg='Bad Request: {}'.format(e.description))
  307. if __name__ == '__main__':
  308. parser = ArgumentParser()
  309. parser.add_argument('--prod', action='store_true')
  310. parser.add_argument('--host', default='127.0.0.1')
  311. parser.add_argument('--port', type=int, default=8083)
  312. parser.add_argument('--log-level', default='INFO')
  313. args = parser.parse_args()
  314. config = configs.get()
  315. logging_level = logging.getLevelName(args.log_level)
  316. logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
  317. # set db config
  318. user_db_config = config['storage']['user']
  319. staff_db_config = config['storage']['staff']
  320. agent_state_db_config = config['storage']['agent_state']
  321. chat_history_db_config = config['storage']['chat_history']
  322. # init user manager
  323. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
  324. app.user_manager = user_manager
  325. # init session manager
  326. session_manager = MySQLSessionManager(
  327. db_config=user_db_config['mysql'],
  328. staff_table=staff_db_config['table'],
  329. user_table=user_db_config['table'],
  330. agent_state_table=agent_state_db_config['table'],
  331. chat_history_table=chat_history_db_config['table']
  332. )
  333. app.session_manager = session_manager
  334. # init staff manager
  335. staff_manager = MySQLStaffManager(
  336. db_config=user_db_config['mysql'],
  337. staff_table=staff_db_config['table'],
  338. user_table=user_db_config['table']
  339. )
  340. app.staff_manager = staff_manager
  341. # init wecom manager
  342. wecom_db_config = config['storage']['user_relation']
  343. user_relation_manager = MySQLUserRelationManager(
  344. user_db_config['mysql'], wecom_db_config['mysql'],
  345. config['storage']['staff']['table'],
  346. user_db_config['table'],
  347. wecom_db_config['table']['staff'],
  348. wecom_db_config['table']['relation'],
  349. wecom_db_config['table']['user']
  350. )
  351. app.user_relation_manager = user_relation_manager
  352. app.history_dialogue_service = HistoryDialogueService(
  353. config['storage']['history_dialogue']['api_base_url']
  354. )
  355. app.run(debug=not args.prod, host=args.host, port=args.port)