api_server.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import time
  5. import logging
  6. import werkzeug.exceptions
  7. from flask import Flask, request, jsonify
  8. from argparse import ArgumentParser
  9. from sqlalchemy.orm import sessionmaker
  10. from pqai_agent import configs
  11. from pqai_agent import logging_service, chat_service, prompt_templates
  12. from pqai_agent.agents.message_reply_agent import MessageReplyAgent
  13. from pqai_agent.data_models.agent_configuration import AgentConfiguration
  14. from pqai_agent.data_models.service_module import ServiceModule
  15. from pqai_agent.history_dialogue_service import HistoryDialogueService
  16. from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
  17. from pqai_agent.utils.db_utils import create_sql_engine
  18. from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
  19. from pqai_agent_server.const import AgentApiConst
  20. from pqai_agent_server.models import MySQLSessionManager
  21. from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
  22. from pqai_agent_server.utils import (
  23. run_extractor_prompt,
  24. run_chat_prompt,
  25. run_response_type_prompt,
  26. )
  27. app = Flask('agent_api_server')
  28. logger = logging_service.logger
  29. const = AgentApiConst()
  30. @app.route('/api/listStaffs', methods=['GET'])
  31. def list_staffs():
  32. staff_data = app.user_relation_manager.list_staffs()
  33. return wrap_response(200, data=staff_data)
  34. @app.route('/api/getStaffProfile', methods=['GET'])
  35. def get_staff_profile():
  36. staff_id = request.args['staff_id']
  37. profile = app.user_manager.get_staff_profile(staff_id)
  38. if not profile:
  39. return wrap_response(404, msg='staff not found')
  40. else:
  41. return wrap_response(200, data=profile)
  42. @app.route('/api/getUserProfile', methods=['GET'])
  43. def get_user_profile():
  44. user_id = request.args['user_id']
  45. profile = app.user_manager.get_user_profile(user_id)
  46. if not profile:
  47. resp = {
  48. 'code': 404,
  49. 'msg': 'user not found'
  50. }
  51. else:
  52. resp = {
  53. 'code': 200,
  54. 'msg': 'success',
  55. 'data': profile
  56. }
  57. return jsonify(resp)
  58. @app.route('/api/listUsers', methods=['GET'])
  59. def list_users():
  60. user_name = request.args.get('user_name', None)
  61. user_union_id = request.args.get('user_union_id', None)
  62. if not user_name and not user_union_id:
  63. resp = {
  64. 'code': 400,
  65. 'msg': 'user_name or user_union_id is required'
  66. }
  67. return jsonify(resp)
  68. data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
  69. return jsonify({'code': 200, 'data': data})
  70. @app.route('/api/getDialogueHistory', methods=['GET'])
  71. def get_dialogue_history():
  72. staff_id = request.args['staff_id']
  73. user_id = request.args['user_id']
  74. recent_minutes = int(request.args.get('recent_minutes', 1440))
  75. dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
  76. return jsonify({'code': 200, 'data': dialogue_history})
  77. @app.route('/api/listModels', methods=['GET'])
  78. def list_models():
  79. models = [
  80. {
  81. 'model_type': 'openai_compatible',
  82. 'model_name': chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
  83. 'display_name': 'DeepSeek V3 on 火山'
  84. },
  85. {
  86. 'model_type': 'openai_compatible',
  87. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  88. 'display_name': '豆包Pro 32K'
  89. },
  90. {
  91. 'model_type': 'openai_compatible',
  92. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  93. 'display_name': '豆包Pro 1.5'
  94. },
  95. {
  96. 'model_type': 'openai_compatible',
  97. 'model_name': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
  98. 'display_name': 'DeepSeek V3联网 on 火山'
  99. },
  100. {
  101. 'model_type': 'openai_compatible',
  102. 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
  103. 'display_name': '豆包1.5视觉理解Pro'
  104. },
  105. ]
  106. return wrap_response(200, data=models)
  107. @app.route('/api/listScenes', methods=['GET'])
  108. def list_scenes():
  109. scenes = [
  110. {'scene': 'greeting', 'display_name': '问候'},
  111. {'scene': 'chitchat', 'display_name': '闲聊'},
  112. {'scene': 'profile_extractor', 'display_name': '画像提取'},
  113. {'scene': 'response_type_detector', 'display_name': '回复模态判断'},
  114. {'scene': 'custom_debugging', 'display_name': '自定义调试场景'}
  115. ]
  116. return wrap_response(200, data=scenes)
  117. @app.route('/api/getBasePrompt', methods=['GET'])
  118. def get_base_prompt():
  119. scene = request.args['scene']
  120. prompt_map = {
  121. 'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
  122. 'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
  123. 'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT,
  124. 'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
  125. 'custom_debugging': '',
  126. }
  127. model_map = {
  128. 'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  129. 'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  130. 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  131. 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
  132. 'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
  133. }
  134. if scene not in prompt_map:
  135. return wrap_response(404, msg='scene not found')
  136. data = {
  137. 'model_name': model_map[scene],
  138. 'content': prompt_map[scene]
  139. }
  140. return wrap_response(200, data=data)
  141. @app.route('/api/runPrompt', methods=['POST'])
  142. def run_prompt():
  143. try:
  144. req_data = request.json
  145. logger.debug(req_data)
  146. scene = req_data['scene']
  147. if scene == 'profile_extractor':
  148. response = run_extractor_prompt(req_data)
  149. return wrap_response(200, data=response)
  150. elif scene == 'response_type_detector':
  151. response = run_response_type_prompt(req_data)
  152. return wrap_response(200, data=response.choices[0].message.content)
  153. else:
  154. response = run_chat_prompt(req_data)
  155. return wrap_response(200, data=response.choices[0].message.content)
  156. except Exception as e:
  157. logger.error(e)
  158. return wrap_response(500, msg='Error: {}'.format(e))
  159. @app.route('/api/formatForPrompt', methods=['POST'])
  160. def format_data_for_prompt():
  161. try:
  162. req_data = request.json
  163. content = req_data['content']
  164. format_type = req_data['format_type']
  165. if format_type == 'staff_profile':
  166. if not isinstance(content, dict):
  167. return wrap_response(400, msg='staff_profile should be a dict')
  168. response = format_agent_profile(content)
  169. elif format_type == 'user_profile':
  170. if not isinstance(content, dict):
  171. return wrap_response(400, msg='user_profile should be a dict')
  172. response = format_user_profile(content)
  173. elif format_type == 'dialogue':
  174. if not isinstance(content, list):
  175. return wrap_response(400, msg='dialogue should be a list')
  176. from pqai_agent_server.utils.prompt_util import format_dialogue_history
  177. response = format_dialogue_history(content)
  178. else:
  179. return wrap_response(400, msg='Invalid format_type')
  180. return wrap_response(200, data=response)
  181. except Exception as e:
  182. logger.error(e)
  183. return wrap_response(500, msg='Error: {}'.format(e))
  184. @app.route("/api/healthCheck", methods=["GET"])
  185. def health_check():
  186. return wrap_response(200, msg="OK")
  187. @app.route("/api/getStaffSessionSummary", methods=["GET"])
  188. def get_staff_session_summary():
  189. staff_id = request.args.get("staff_id")
  190. status = request.args.get("status", const.DEFAULT_STAFF_STATUS)
  191. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  192. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  193. # check params
  194. try:
  195. page_id = int(page_id)
  196. page_size = int(page_size)
  197. status = int(status)
  198. except Exception as e:
  199. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  200. staff_session_summary = app.session_manager.get_staff_sessions_summary(
  201. staff_id, page_id, page_size, status
  202. )
  203. if not staff_session_summary:
  204. return wrap_response(404, msg="staff not found")
  205. else:
  206. return wrap_response(200, data=staff_session_summary)
  207. @app.route("/api/getStaffSessionList", methods=["GET"])
  208. def get_staff_session_list():
  209. staff_id = request.args.get("staff_id")
  210. if not staff_id:
  211. return wrap_response(404, msg="staff_id is required")
  212. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  213. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  214. # check params
  215. try:
  216. page_id = int(page_id)
  217. page_size = int(page_size)
  218. except Exception as e:
  219. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  220. staff_session_list = app.session_manager.get_staff_session_list(staff_id, page_id, page_size)
  221. if not staff_session_list:
  222. return wrap_response(404, msg="staff not found")
  223. return wrap_response(200, data=staff_session_list)
  224. @app.route("/api/getStaffList", methods=["GET"])
  225. def get_staff_list():
  226. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  227. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  228. # check params
  229. try:
  230. page_id = int(page_id)
  231. page_size = int(page_size)
  232. except Exception as e:
  233. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  234. staff_list = app.user_manager.get_staff_list(page_id, page_size)
  235. if not staff_list:
  236. return wrap_response(404, msg="staff not found")
  237. return wrap_response(200, data=staff_list)
  238. @app.route("/api/getConversationList", methods=["GET"])
  239. def get_conversation_list():
  240. """
  241. 获取staff && user 私聊对话列表
  242. :return:
  243. """
  244. staff_id = request.args.get("staff_id")
  245. user_id = request.args.get("user_id")
  246. if not staff_id or not user_id:
  247. return wrap_response(404, msg="staff_id and user_id are required")
  248. page = request.args.get("page")
  249. response = app.session_manager.get_conversation_list(staff_id, user_id, page, const.DEFAULT_CONVERSATION_SIZE)
  250. return wrap_response(200, data=response)
  251. @app.route("/api/sendMessage", methods=["POST"])
  252. def send_message():
  253. return wrap_response(200, msg="暂不实现功能")
  254. @app.route("/api/quitHumanInterventionStatus", methods=["POST"])
  255. def quit_human_interventions_status():
  256. """
  257. 退出人工介入状态
  258. :return:
  259. """
  260. req_data = request.json
  261. staff_id = req_data["staff_id"]
  262. user_id = req_data["user_id"]
  263. if not user_id or not staff_id:
  264. return wrap_response(404, msg="user_id and staff_id are required")
  265. response = quit_human_intervention_status(user_id, staff_id)
  266. return wrap_response(200, data=response)
  267. ## Agent管理接口
  268. @app.route("/api/getNativeAgentList", methods=["GET"])
  269. def get_native_agent_list():
  270. """
  271. 获取所有的Agent列表
  272. :return:
  273. """
  274. page = request.args.get('page', 1)
  275. page_size = request.args.get('page_size', 50)
  276. create_user = request.args.get('create_user', None)
  277. update_user = request.args.get('update_user', None)
  278. offset = (int(page) - 1) * int(page_size)
  279. with app.session_maker() as session:
  280. query = session.query(AgentConfiguration) \
  281. .filter(AgentConfiguration.is_delete == 0)
  282. if create_user:
  283. query = query.filter(AgentConfiguration.create_user == create_user)
  284. if update_user:
  285. query = query.filter(AgentConfiguration.update_user == update_user)
  286. query = query.offset(offset).limit(int(page_size))
  287. data = query.all()
  288. ret_data = [
  289. {
  290. 'id': agent.id,
  291. 'name': agent.name,
  292. 'display_name': agent.display_name,
  293. 'type': agent.type,
  294. 'execution_model': agent.execution_model,
  295. 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  296. 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
  297. }
  298. for agent in data
  299. ]
  300. return wrap_response(200, data=ret_data)
  301. @app.route("/api/getNativeAgentConfiguration", methods=["GET"])
  302. def get_native_agent_configuration():
  303. """
  304. 获取指定Agent的配置
  305. :return:
  306. """
  307. agent_id = request.args.get('agent_id')
  308. if not agent_id:
  309. return wrap_response(404, msg='agent_id is required')
  310. with app.session_maker() as session:
  311. agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
  312. if not agent:
  313. return wrap_response(404, msg='Agent not found')
  314. data = {
  315. 'id': agent.id,
  316. 'name': agent.name,
  317. 'display_name': agent.display_name,
  318. 'type': agent.type,
  319. 'execution_model': agent.execution_model,
  320. 'system_prompt': agent.system_prompt,
  321. 'task_prompt': agent.task_prompt,
  322. 'tools': agent.tools,
  323. 'sub_agents': agent.sub_agents,
  324. 'extra_params': agent.extra_params,
  325. 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  326. 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
  327. }
  328. return wrap_response(200, data=data)
  329. @app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
  330. def save_native_agent_configuration():
  331. """
  332. 保存Agent配置
  333. :return:
  334. """
  335. req_data = request.json
  336. agent_id = req_data.get('agent_id', None)
  337. name = req_data.get('name')
  338. display_name = req_data.get('display_name', None)
  339. type_ = req_data.get('type', 0)
  340. execution_model = req_data.get('execution_model', None)
  341. system_prompt = req_data.get('system_prompt', None)
  342. task_prompt = req_data.get('task_prompt', None)
  343. tools = req_data.get('tools', [])
  344. sub_agents = req_data.get('sub_agents', [])
  345. extra_params = req_data.get('extra_params', {})
  346. if not name:
  347. return wrap_response(400, msg='name is required')
  348. with app.session_maker() as session:
  349. if agent_id:
  350. agent_id = int(agent_id)
  351. agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
  352. if not agent:
  353. return wrap_response(404, msg='Agent not found')
  354. agent.name = name
  355. agent.display_name = display_name
  356. agent.type = type_
  357. agent.execution_model = execution_model
  358. agent.system_prompt = system_prompt
  359. agent.task_prompt = task_prompt
  360. agent.tools = tools
  361. agent.sub_agents = sub_agents
  362. agent.extra_params = extra_params
  363. else:
  364. agent = AgentConfiguration(
  365. name=name,
  366. display_name=display_name,
  367. type=type_,
  368. execution_model=execution_model,
  369. system_prompt=system_prompt,
  370. task_prompt=task_prompt,
  371. tools=tools,
  372. sub_agents=sub_agents,
  373. extra_params=extra_params
  374. )
  375. session.add(agent)
  376. session.commit()
  377. return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
  378. @app.route("/api/getModuleList", methods=["GET"])
  379. def get_module_list():
  380. """
  381. 获取所有的模块列表
  382. :return:
  383. """
  384. with app.session_maker() as session:
  385. query = session.query(ServiceModule) \
  386. .filter(ServiceModule.is_delete == 0)
  387. data = query.all()
  388. ret_data = [
  389. {
  390. 'id': module.id,
  391. 'name': module.name,
  392. 'display_name': module.display_name,
  393. 'default_agent_type': module.default_agent_type,
  394. 'default_agent_id': module.default_agent_id,
  395. 'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  396. 'update_time': module.update_time.strftime('%Y-%m-%d %H:%M:%S')
  397. }
  398. for module in data
  399. ]
  400. return wrap_response(200, data=ret_data)
  401. @app.route("/api/getModuleConfiguration", methods=["GET"])
  402. def get_module_configuration():
  403. """
  404. 获取指定模块的配置
  405. :return:
  406. """
  407. module_id = request.args.get('module_id')
  408. if not module_id:
  409. return wrap_response(404, msg='module_id is required')
  410. with app.session_maker() as session:
  411. module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
  412. if not module:
  413. return wrap_response(404, msg='Module not found')
  414. data = {
  415. 'id': module.id,
  416. 'name': module.name,
  417. 'display_name': module.display_name,
  418. 'default_agent_type': module.default_agent_type,
  419. 'default_agent_id': module.default_agent_id,
  420. 'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  421. 'updated_time': module.updated_time.strftime('%Y-%m-%d %H:%M:%S')
  422. }
  423. return wrap_response(200, data=data)
  424. @app.route("/api/saveModuleConfiguration", methods=["POST"])
  425. def save_module_configuration():
  426. """
  427. 保存模块配置
  428. :return:
  429. """
  430. req_data = request.json
  431. module_id = req_data.get('module_id', None)
  432. name = req_data.get('name')
  433. display_name = req_data.get('display_name', None)
  434. default_agent_type = req_data.get('default_agent_type', 0)
  435. default_agent_id = req_data.get('default_agent_id', None)
  436. if not name:
  437. return wrap_response(400, msg='name is required')
  438. with app.session_maker() as session:
  439. if module_id:
  440. module_id = int(module_id)
  441. module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
  442. if not module:
  443. return wrap_response(404, msg='Module not found')
  444. module.name = name
  445. module.display_name = display_name
  446. module.default_agent_type = default_agent_type
  447. module.default_agent_id = default_agent_id
  448. else:
  449. module = ServiceModule(
  450. name=name,
  451. display_name=display_name,
  452. default_agent_type=default_agent_type,
  453. default_agent_id=default_agent_id
  454. )
  455. session.add(module)
  456. session.commit()
  457. return wrap_response(200, msg='Module configuration saved successfully', data={'id': module.id})
  458. @app.errorhandler(werkzeug.exceptions.BadRequest)
  459. def handle_bad_request(e):
  460. logger.error(e)
  461. return wrap_response(400, msg='Bad Request: {}'.format(e.description))
  462. if __name__ == '__main__':
  463. parser = ArgumentParser()
  464. parser.add_argument('--prod', action='store_true')
  465. parser.add_argument('--host', default='127.0.0.1')
  466. parser.add_argument('--port', type=int, default=8083)
  467. parser.add_argument('--log-level', default='INFO')
  468. args = parser.parse_args()
  469. config = configs.get()
  470. logging_level = logging.getLevelName(args.log_level)
  471. logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
  472. # set db config
  473. user_db_config = config['storage']['user']
  474. staff_db_config = config['storage']['staff']
  475. agent_state_db_config = config['storage']['agent_state']
  476. chat_history_db_config = config['storage']['chat_history']
  477. # init user manager
  478. user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
  479. app.user_manager = user_manager
  480. # init session manager
  481. session_manager = MySQLSessionManager(
  482. db_config=user_db_config['mysql'],
  483. staff_table=staff_db_config['table'],
  484. user_table=user_db_config['table'],
  485. agent_state_table=agent_state_db_config['table'],
  486. chat_history_table=chat_history_db_config['table']
  487. )
  488. app.session_manager = session_manager
  489. agent_db_engine = create_sql_engine(config['storage']['agent_state']['mysql'])
  490. app.session_maker = sessionmaker(bind=agent_db_engine)
  491. wecom_db_config = config['storage']['user_relation']
  492. user_relation_manager = MySQLUserRelationManager(
  493. user_db_config['mysql'], wecom_db_config['mysql'],
  494. config['storage']['staff']['table'],
  495. user_db_config['table'],
  496. wecom_db_config['table']['staff'],
  497. wecom_db_config['table']['relation'],
  498. wecom_db_config['table']['user']
  499. )
  500. app.user_relation_manager = user_relation_manager
  501. app.history_dialogue_service = HistoryDialogueService(
  502. config['storage']['history_dialogue']['api_base_url']
  503. )
  504. app.run(debug=not args.prod, host=args.host, port=args.port)