api_server.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import time
  5. import logging
  6. import werkzeug.exceptions
  7. from flask import Flask, request, jsonify
  8. from argparse import ArgumentParser
  9. from sqlalchemy.orm import sessionmaker
  10. from pqai_agent import configs
  11. from pqai_agent import logging_service, chat_service, prompt_templates
  12. from pqai_agent.agents.message_reply_agent import MessageReplyAgent
  13. from pqai_agent.data_models.agent_configuration import AgentConfiguration
  14. from pqai_agent.data_models.service_module import ServiceModule
  15. from pqai_agent.history_dialogue_service import HistoryDialogueService
  16. from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
  17. from pqai_agent.utils.db_utils import create_sql_engine
  18. from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
  19. from pqai_agent_server.const import AgentApiConst
  20. from pqai_agent_server.const.status_enum import TestTaskStatus
  21. from pqai_agent_server.models import MySQLSessionManager
  22. from pqai_agent_server.task_server import TaskManager
  23. from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
  24. from pqai_agent_server.utils import (
  25. run_extractor_prompt,
  26. run_chat_prompt,
  27. run_response_type_prompt,
  28. )
  # Module-level objects shared by every route handler in this file.
  app = Flask('agent_api_server')
  const = AgentApiConst()  # supplies the DEFAULT_* values used by the paging endpoints
  logger = logging_service.logger
  @app.route('/api/listStaffs', methods=['GET'])
  def list_staffs():
      """Return the staff list produced by ``app.user_relation_manager.list_staffs()``."""
      return wrap_response(200, data=app.user_relation_manager.list_staffs())
  @app.route('/api/getStaffProfile', methods=['GET'])
  def get_staff_profile():
      """Look up one staff profile by the ``staff_id`` query parameter.

      A missing ``staff_id`` makes ``request.args[...]`` raise werkzeug's
      ``BadRequestKeyError`` (a ``BadRequest``), which the module-level
      ``BadRequest`` error handler turns into a 400 reply.
      Responds 404 when the user manager returns an empty/falsy profile.
      """
      profile = app.user_manager.get_staff_profile(request.args['staff_id'])
      if profile:
          return wrap_response(200, data=profile)
      return wrap_response(404, msg='staff not found')
  @app.route('/api/getUserProfile', methods=['GET'])
  def get_user_profile():
      """Look up one user profile by the ``user_id`` query parameter.

      Unlike most endpoints here, the reply is built with ``jsonify`` directly:
      ``{'code': 200, 'msg': 'success', 'data': ...}`` on success, or
      ``{'code': 404, 'msg': 'user not found'}`` when the profile is falsy.
      """
      user_id = request.args['user_id']
      profile = app.user_manager.get_user_profile(user_id)
      if profile:
          return jsonify({'code': 200, 'msg': 'success', 'data': profile})
      return jsonify({'code': 404, 'msg': 'user not found'})
  @app.route('/api/listUsers', methods=['GET'])
  def list_users():
      """Search users by ``user_name`` and/or ``user_union_id``.

      At least one of the two query parameters must be non-empty; otherwise
      a ``code: 400`` JSON body is returned.
      """
      user_name = request.args.get('user_name')
      user_union_id = request.args.get('user_union_id')
      if user_name or user_union_id:
          data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
          return jsonify({'code': 200, 'data': data})
      return jsonify({'code': 400, 'msg': 'user_name or user_union_id is required'})
  @app.route('/api/getDialogueHistory', methods=['GET'])
  def get_dialogue_history():
      """Return the dialogue history between a staff member and a user.

      Query parameters:
          staff_id: required (missing -> BadRequest -> 400 via the error handler).
          user_id: required (same handling as ``staff_id``).
          recent_minutes: optional integer, default 1440; forwarded to
              ``history_dialogue_service.get_dialogue_history``.

      Responds 400 when ``recent_minutes`` is not an integer instead of
      letting ``int()`` raise into an unhandled 500.
      """
      staff_id = request.args['staff_id']
      user_id = request.args['user_id']
      try:
          recent_minutes = int(request.args.get('recent_minutes', 1440))
      except (TypeError, ValueError) as e:
          return wrap_response(400, msg='Invalid parameter: {}'.format(e))
      dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
      return jsonify({'code': 200, 'data': dialogue_history})
  @app.route('/api/listModels', methods=['GET'])
  def list_models():
      """Return the fixed catalogue of chat models this server exposes.

      Every entry is tagged ``model_type='openai_compatible'`` and pairs a
      ``chat_service`` model constant with a human-readable display name.
      """
      catalogue = (
          (chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3, 'DeepSeek V3 on 火山'),
          (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K, '豆包Pro 32K'),
          (chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5, '豆包Pro 1.5'),
          (chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH, 'DeepSeek V3联网 on 火山'),
          (chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO, '豆包1.5视觉理解Pro'),
      )
      models = [
          {'model_type': 'openai_compatible', 'model_name': model_name, 'display_name': label}
          for model_name, label in catalogue
      ]
      return wrap_response(200, data=models)
  @app.route('/api/listScenes', methods=['GET'])
  def list_scenes():
      """Return the fixed list of prompt-debugging scenes with display names."""
      scene_labels = (
          ('greeting', '问候'),
          ('chitchat', '闲聊'),
          ('profile_extractor', '画像提取'),
          ('response_type_detector', '回复模态判断'),
          ('custom_debugging', '自定义调试场景'),
      )
      scenes = [{'scene': scene, 'display_name': label} for scene, label in scene_labels]
      return wrap_response(200, data=scenes)
  @app.route('/api/getBasePrompt', methods=['GET'])
  def get_base_prompt():
      """Return the default prompt template and model for the ``scene`` query parameter.

      Responds 404 for an unknown scene. ``custom_debugging`` has an empty
      template.
      """
      scene = request.args['scene']
      # scene -> (prompt template, default model name)
      scene_defaults = {
          'greeting': (prompt_templates.GENERAL_GREETING_PROMPT,
                       chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K),
          'chitchat': (prompt_templates.CHITCHAT_PROMPT_COZE,
                       chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K),
          'profile_extractor': (prompt_templates.USER_PROFILE_EXTRACT_PROMPT_V2,
                                chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5),
          'response_type_detector': (prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
                                     chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5),
          'custom_debugging': ('', chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH),
      }
      entry = scene_defaults.get(scene)
      if entry is None:
          return wrap_response(404, msg='scene not found')
      content, model_name = entry
      return wrap_response(200, data={'model_name': model_name, 'content': content})
  @app.route('/api/runPrompt', methods=['POST'])
  def run_prompt():
      """Run a prompt for the requested debugging scene and return the model output.

      The JSON body must be an object containing ``scene``; the whole body is
      forwarded to the scene-specific runner:

      * ``profile_extractor`` -> ``run_extractor_prompt``; its result is returned as-is.
      * ``response_type_detector`` -> ``run_response_type_prompt``.
      * anything else -> ``run_chat_prompt``.

      For the last two, ``choices[0].message.content`` of the runner's response
      is returned.

      Responds 400 when the body is not a JSON object with a ``scene`` field,
      and 500 (with the error text) when a runner fails.
      """
      # silent=True: a missing/malformed body becomes None and is rejected
      # below with a 400, instead of surfacing as a generic 500.
      req_data = request.get_json(silent=True)
      logger.debug(req_data)
      if not isinstance(req_data, dict) or 'scene' not in req_data:
          return wrap_response(400, msg='scene is required')
      scene = req_data['scene']
      try:
          if scene == 'profile_extractor':
              return wrap_response(200, data=run_extractor_prompt(req_data))
          if scene == 'response_type_detector':
              response = run_response_type_prompt(req_data)
          else:
              response = run_chat_prompt(req_data)
          return wrap_response(200, data=response.choices[0].message.content)
      except Exception as e:
          # Top-level boundary: keep the traceback in the log, report the message.
          logger.exception(e)
          return wrap_response(500, msg='Error: {}'.format(e))
  @app.route('/api/formatForPrompt', methods=['POST'])
  def format_data_for_prompt():
      """Format structured data with the prompt-formatting helpers.

      JSON body:
          content: a dict for ``staff_profile``/``user_profile``, a list for ``dialogue``.
          format_type: ``'staff_profile'``, ``'user_profile'`` or ``'dialogue'``.

      Responds 400 for a missing/invalid body, a wrong ``content`` type or an
      unknown ``format_type``; 500 (with the error text) if a formatter fails.
      """
      req_data = request.get_json(silent=True)
      if not isinstance(req_data, dict) or 'content' not in req_data or 'format_type' not in req_data:
          return wrap_response(400, msg='content and format_type are required')
      content = req_data['content']
      format_type = req_data['format_type']
      try:
          if format_type == 'staff_profile':
              if not isinstance(content, dict):
                  return wrap_response(400, msg='staff_profile should be a dict')
              response = format_agent_profile(content)
          elif format_type == 'user_profile':
              if not isinstance(content, dict):
                  return wrap_response(400, msg='user_profile should be a dict')
              response = format_user_profile(content)
          elif format_type == 'dialogue':
              if not isinstance(content, list):
                  return wrap_response(400, msg='dialogue should be a list')
              # NOTE(review): this module path differs from the file-level import
              # ``pqai_agent.utils.prompt_utils``; confirm it is the intended one.
              from pqai_agent_server.utils.prompt_util import format_dialogue_history
              response = format_dialogue_history(content)
          else:
              return wrap_response(400, msg='Invalid format_type')
          return wrap_response(200, data=response)
      except Exception as e:
          # Top-level boundary: keep the traceback in the log, report the message.
          logger.exception(e)
          return wrap_response(500, msg='Error: {}'.format(e))
  @app.route("/api/healthCheck", methods=["GET"])
  def health_check():
      """Report that the server process is up: always replies 200 with message "OK"."""
      ok_code = 200
      return wrap_response(ok_code, msg="OK")
  @app.route("/api/getStaffSessionSummary", methods=["GET"])
  def get_staff_session_summary():
      """Return a paged session summary for a staff member.

      ``staff_id`` is forwarded as-is (possibly None). ``status``, ``page_id``
      and ``page_size`` default to the ``const`` values and must be integers;
      a non-integer yields a 404 "Invalid parameter" reply, and an empty
      summary yields 404 "staff not found".
      """
      args = request.args
      staff_id = args.get("staff_id")
      raw_status = args.get("status", const.DEFAULT_STAFF_STATUS)
      raw_page_id = args.get("page_id", const.DEFAULT_PAGE_ID)
      raw_page_size = args.get("page_size", const.DEFAULT_PAGE_SIZE)
      try:
          # Converted in the same order as before so the first bad value is reported.
          page_id, page_size, status = map(int, (raw_page_id, raw_page_size, raw_status))
      except (TypeError, ValueError) as e:
          return wrap_response(404, msg="Invalid parameter: {}".format(e))
      summary = app.session_manager.get_staff_sessions_summary(staff_id, page_id, page_size, status)
      if summary:
          return wrap_response(200, data=summary)
      return wrap_response(404, msg="staff not found")
  @app.route("/api/getStaffSessionList", methods=["GET"])
  def get_staff_session_list():
      """Return one page of sessions for the required ``staff_id``.

      ``page_id``/``page_size`` default to the ``const`` values and must be
      integers (404 "Invalid parameter" otherwise); an empty result yields
      404 "staff not found".
      """
      args = request.args
      staff_id = args.get("staff_id")
      if not staff_id:
          return wrap_response(404, msg="staff_id is required")
      try:
          page_id, page_size = map(int, (
              args.get("page_id", const.DEFAULT_PAGE_ID),
              args.get("page_size", const.DEFAULT_PAGE_SIZE),
          ))
      except (TypeError, ValueError) as e:
          return wrap_response(404, msg="Invalid parameter: {}".format(e))
      sessions = app.session_manager.get_staff_session_list(staff_id, page_id, page_size)
      if not sessions:
          return wrap_response(404, msg="staff not found")
      return wrap_response(200, data=sessions)
  @app.route("/api/getStaffList", methods=["GET"])
  def get_staff_list():
      """Return one page of staff members from the user manager.

      ``page_id``/``page_size`` default to the ``const`` values and must be
      integers (404 "Invalid parameter" otherwise); an empty page yields
      404 "staff not found".
      """
      args = request.args
      try:
          page_id, page_size = map(int, (
              args.get("page_id", const.DEFAULT_PAGE_ID),
              args.get("page_size", const.DEFAULT_PAGE_SIZE),
          ))
      except (TypeError, ValueError) as e:
          return wrap_response(404, msg="Invalid parameter: {}".format(e))
      staff_page = app.user_manager.get_staff_list(page_id, page_size)
      if not staff_page:
          return wrap_response(404, msg="staff not found")
      return wrap_response(200, data=staff_page)
  @app.route("/api/getConversationList", methods=["GET"])
  def get_conversation_list():
      """Get the private-chat conversation list between a staff member and a user.

      ``staff_id`` and ``user_id`` are required. ``page`` is forwarded as the raw
      query-string value (a str, or None when absent) together with
      ``const.DEFAULT_CONVERSATION_SIZE``.
      NOTE(review): confirm the session manager expects ``page`` un-converted.
      """
      args = request.args
      staff_id, user_id = args.get("staff_id"), args.get("user_id")
      if not (staff_id and user_id):
          return wrap_response(404, msg="staff_id and user_id are required")
      conversations = app.session_manager.get_conversation_list(
          staff_id, user_id, args.get("page"), const.DEFAULT_CONVERSATION_SIZE
      )
      return wrap_response(200, data=conversations)
  @app.route("/api/sendMessage", methods=["POST"])
  def send_message():
      """Placeholder: sending messages is not implemented; always replies 200 with a notice."""
      not_implemented_notice = "暂不实现功能"
      return wrap_response(200, msg=not_implemented_notice)
  @app.route("/api/quitHumanInterventionStatus", methods=["POST"])
  def quit_human_interventions_status():
      """Exit the human-intervention state for a staff/user pair.

      JSON body: ``staff_id`` and ``user_id``, both required. A missing or empty
      value (or a body that is not a JSON object) yields a 404 "required"
      reply instead of an unhandled ``KeyError``/``AttributeError``.
      Returns whatever ``quit_human_intervention_status`` returns as ``data``.
      """
      req_data = request.get_json(silent=True)
      if not isinstance(req_data, dict):
          req_data = {}
      staff_id = req_data.get("staff_id")
      user_id = req_data.get("user_id")
      if not user_id or not staff_id:
          return wrap_response(404, msg="user_id and staff_id are required")
      response = quit_human_intervention_status(user_id, staff_id)
      return wrap_response(200, data=response)
  # Agent management endpoints
  @app.route("/api/getNativeAgentList", methods=["GET"])
  def get_native_agent_list():
      """List agent configurations one page at a time.

      Only rows with ``is_delete == 0`` are returned (presumably the
      soft-delete flag).

      Query parameters:
          page: 1-based page number, default 1.
          page_size: rows per page, default 50.
          create_user: optional exact-match filter.
          update_user: optional exact-match filter.

      Responds 400 when ``page``/``page_size`` are not positive integers.
      """
      try:
          page = int(request.args.get('page', 1))
          page_size = int(request.args.get('page_size', 50))
      except (TypeError, ValueError) as e:
          return wrap_response(400, msg='Invalid parameter: {}'.format(e))
      if page < 1 or page_size < 1:
          # Prevents a negative OFFSET/LIMIT from reaching the database.
          return wrap_response(400, msg='page and page_size must be positive integers')
      create_user = request.args.get('create_user', None)
      update_user = request.args.get('update_user', None)
      time_fmt = '%Y-%m-%d %H:%M:%S'
      with app.session_maker() as session:
          query = session.query(AgentConfiguration).filter(AgentConfiguration.is_delete == 0)
          if create_user:
              query = query.filter(AgentConfiguration.create_user == create_user)
          if update_user:
              query = query.filter(AgentConfiguration.update_user == update_user)
          # Explicit ordering so OFFSET/LIMIT paging is stable between requests.
          agents = (
              query.order_by(AgentConfiguration.id)
              .offset((page - 1) * page_size)
              .limit(page_size)
              .all()
          )
          ret_data = [
              {
                  'id': agent.id,
                  'name': agent.name,
                  'display_name': agent.display_name,
                  'type': agent.type,
                  'execution_model': agent.execution_model,
                  'create_time': agent.create_time.strftime(time_fmt),
                  'update_time': agent.update_time.strftime(time_fmt),
              }
              for agent in agents
          ]
      return wrap_response(200, data=ret_data)
  @app.route("/api/getNativeAgentConfiguration", methods=["GET"])
  def get_native_agent_configuration():
      """Return the full configuration of one agent, selected by ``agent_id``.

      Responds 404 when ``agent_id`` is missing or no row has that id.
      NOTE(review): unlike getNativeAgentList, this lookup does not filter on
      ``is_delete``; confirm deleted agents should remain readable here.
      """
      agent_id = request.args.get('agent_id')
      if not agent_id:
          return wrap_response(404, msg='agent_id is required')
      time_fmt = '%Y-%m-%d %H:%M:%S'
      with app.session_maker() as session:
          agent = (
              session.query(AgentConfiguration)
              .filter(AgentConfiguration.id == agent_id)
              .first()
          )
          if not agent:
              return wrap_response(404, msg='Agent not found')
          data = dict(
              id=agent.id,
              name=agent.name,
              display_name=agent.display_name,
              type=agent.type,
              execution_model=agent.execution_model,
              system_prompt=agent.system_prompt,
              task_prompt=agent.task_prompt,
              tools=agent.tools,
              sub_agents=agent.sub_agents,
              extra_params=agent.extra_params,
              create_time=agent.create_time.strftime(time_fmt),
              update_time=agent.update_time.strftime(time_fmt),
          )
      return wrap_response(200, data=data)
  @app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
  def save_native_agent_configuration():
      """Create an agent configuration, or overwrite an existing one.

      When ``agent_id`` in the JSON body is truthy, the matching row is updated
      and every field is overwritten (omitted fields fall back to the same
      defaults used for creation). Otherwise a new row is inserted.
      ``name`` is required.

      Responds 400 for a non-object body, a missing ``name`` or a non-integer
      ``agent_id``; 404 when ``agent_id`` does not exist. On success returns
      ``{'id': <saved agent id>}``.
      """
      req_data = request.get_json(silent=True)
      if not isinstance(req_data, dict):
          return wrap_response(400, msg='request body must be a JSON object')
      agent_id = req_data.get('agent_id', None)
      name = req_data.get('name')
      if not name:
          return wrap_response(400, msg='name is required')
      # Decide update-vs-create on the raw value, as before (e.g. "0" means update).
      is_update = bool(agent_id)
      if is_update:
          try:
              agent_id = int(agent_id)
          except (TypeError, ValueError):
              return wrap_response(400, msg='agent_id must be an integer')
      fields = {
          'name': name,
          'display_name': req_data.get('display_name', None),
          'type': req_data.get('type', 0),
          'execution_model': req_data.get('execution_model', None),
          'system_prompt': req_data.get('system_prompt', None),
          'task_prompt': req_data.get('task_prompt', None),
          'tools': req_data.get('tools', []),
          'sub_agents': req_data.get('sub_agents', []),
          'extra_params': req_data.get('extra_params', {}),
      }
      with app.session_maker() as session:
          if is_update:
              agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
              if not agent:
                  return wrap_response(404, msg='Agent not found')
              for attr, value in fields.items():
                  setattr(agent, attr, value)
          else:
              agent = AgentConfiguration(**fields)
              session.add(agent)
          session.commit()
          # With default sessionmaker settings commit() expires attributes; read
          # the id while the instance is still bound to an open session.
          saved_id = agent.id
      return wrap_response(200, msg='Agent configuration saved successfully', data={'id': saved_id})
  @app.route("/api/getModuleList", methods=["GET"])
  def get_module_list():
      """List every service module whose ``is_delete`` flag is 0."""
      time_fmt = '%Y-%m-%d %H:%M:%S'
      with app.session_maker() as session:
          modules = session.query(ServiceModule).filter(ServiceModule.is_delete == 0).all()
          ret_data = []
          for module in modules:
              ret_data.append({
                  'id': module.id,
                  'name': module.name,
                  'display_name': module.display_name,
                  'default_agent_type': module.default_agent_type,
                  'default_agent_id': module.default_agent_id,
                  'create_time': module.create_time.strftime(time_fmt),
                  'update_time': module.update_time.strftime(time_fmt),
              })
      return wrap_response(200, data=ret_data)
  @app.route("/api/getModuleConfiguration", methods=["GET"])
  def get_module_configuration():
      """Return the configuration of one service module, selected by ``module_id``.

      Responds 404 when ``module_id`` is missing or no module has that id.
      """
      module_id = request.args.get('module_id')
      if not module_id:
          return wrap_response(404, msg='module_id is required')
      time_fmt = '%Y-%m-%d %H:%M:%S'
      with app.session_maker() as session:
          module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
          if not module:
              return wrap_response(404, msg='Module not found')
          data = {
              'id': module.id,
              'name': module.name,
              'display_name': module.display_name,
              'default_agent_type': module.default_agent_type,
              'default_agent_id': module.default_agent_id,
              'create_time': module.create_time.strftime(time_fmt),
              # getModuleList reads this column as ``update_time``; the former
              # ``updated_time`` attribute/key did not match it.
              'update_time': module.update_time.strftime(time_fmt),
          }
      return wrap_response(200, data=data)
  @app.route("/api/saveModuleConfiguration", methods=["POST"])
  def save_module_configuration():
      """Create a service module configuration, or overwrite an existing one.

      When ``module_id`` in the JSON body is truthy, the matching row is updated;
      otherwise a new row is inserted. ``name`` is required;
      ``default_agent_type`` defaults to 0 and ``default_agent_id`` to None.

      Responds 400 for a non-object body, a missing ``name`` or a non-integer
      ``module_id``; 404 when ``module_id`` does not exist. On success returns
      ``{'id': <saved module id>}``.
      """
      req_data = request.get_json(silent=True)
      if not isinstance(req_data, dict):
          return wrap_response(400, msg='request body must be a JSON object')
      module_id = req_data.get('module_id', None)
      name = req_data.get('name')
      display_name = req_data.get('display_name', None)
      default_agent_type = req_data.get('default_agent_type', 0)
      default_agent_id = req_data.get('default_agent_id', None)
      if not name:
          return wrap_response(400, msg='name is required')
      # Decide update-vs-create on the raw value, as before (e.g. "0" means update).
      is_update = bool(module_id)
      if is_update:
          try:
              module_id = int(module_id)
          except (TypeError, ValueError):
              return wrap_response(400, msg='module_id must be an integer')
      with app.session_maker() as session:
          if is_update:
              module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
              if not module:
                  return wrap_response(404, msg='Module not found')
              module.name = name
              module.display_name = display_name
              module.default_agent_type = default_agent_type
              module.default_agent_id = default_agent_id
          else:
              module = ServiceModule(
                  name=name,
                  display_name=display_name,
                  default_agent_type=default_agent_type,
                  default_agent_id=default_agent_id,
              )
              session.add(module)
          session.commit()
          # With default sessionmaker settings commit() expires attributes; read
          # the id while the instance is still bound to an open session.
          saved_id = module.id
      return wrap_response(200, msg='Module configuration saved successfully', data={'id': saved_id})
  @app.route("/api/getTestTaskList", methods=["GET"])
  def get_test_task_list():
      """Get the list of unit-test tasks, paged by ``pageNum``/``pageSize``.

      Both default to the ``const`` values; a non-integer yields a 404
      "Invalid parameter" reply.
      """
      args = request.args
      try:
          page_num = int(args.get("pageNum", const.DEFAULT_PAGE_ID))
          page_size = int(args.get("pageSize", const.DEFAULT_PAGE_SIZE))
      except (TypeError, ValueError) as e:
          return wrap_response(404, msg="Invalid parameter: {}".format(e))
      return wrap_response(200, data=app.task_manager.get_test_task_list(page_num, page_size))
  @app.route("/api/getTestTaskConversations", methods=["GET"])
  def get_test_task_conversations():
      """Get the conversation entries of one unit-test task, paged.

      Query parameters: ``taskId`` (required), ``pageNum`` and ``pageSize``
      (default to the ``const`` values). A missing ``taskId`` or any
      non-integer value yields a 404 reply, matching the sibling paging
      endpoints, instead of an unhandled ``ValueError`` for ``taskId``.
      """
      task_id = request.args.get("taskId", None)
      if not task_id:
          return wrap_response(404, msg='task_id is required')
      page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
      page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
      try:
          task_id = int(task_id)
          page_num = int(page_num)
          page_size = int(page_size)
      except (TypeError, ValueError) as e:
          return wrap_response(404, msg="Invalid parameter: {}".format(e))
      response = app.task_manager.get_test_task_conversations(task_id, page_num, page_size)
      return wrap_response(200, data=response)
  @app.route("/api/createTestTask", methods=["POST"])
  def create_test_task():
      """Create a unit-test task for ``agentId`` and ``modelId`` from the JSON body.

      Both fields are required; a missing/empty one yields a 404 reply.
      """
      payload = request.json
      agent_id = payload.get('agentId')
      model_id = payload.get('modelId')
      required = ((agent_id, 'agent id is required'), (model_id, 'model id is required'))
      for value, missing_msg in required:
          if not value:
              return wrap_response(404, msg=missing_msg)
      app.task_manager.create_task(agent_id, model_id)
      return wrap_response(200)
  @app.route("/api/stopTestTask", methods=["POST"])
  def stop_test_task():
      """Stop (cancel) a unit-test task given ``taskId`` in the JSON body.

      Only tasks in NOT_STARTED or IN_PROGRESS status can be stopped.
      Responds 400 for a missing ``taskId`` or an unstoppable status, and
      404 when ``get_task`` finds no task.
      """
      req_data = request.get_json(silent=True)
      if not isinstance(req_data, dict):
          req_data = {}
      task_id = req_data.get('taskId', None)
      if not task_id:
          return wrap_response(400, msg='task id is required')
      task = app.task_manager.get_task(task_id)
      # Presumably get_task returns None for an unknown id; without this guard
      # ``task.status`` raised AttributeError (HTTP 500).
      if not task:
          return wrap_response(404, msg='task not found')
      stoppable = (TestTaskStatus.NOT_STARTED.value, TestTaskStatus.IN_PROGRESS.value)
      if task.status not in stoppable:
          return wrap_response(400, msg='task status is invalid')
      app.task_manager.cancel_task(task_id)
      return wrap_response(200)
  @app.route("/api/resumeTestTask", methods=["POST"])
  def resume_test_task():
      """Resume a stopped (CANCELLED) unit-test task given ``taskId`` in the JSON body.

      Responds 400 for a missing ``taskId`` or a task not in CANCELLED status,
      and 404 when ``get_task`` finds no task.
      """
      req_data = request.get_json(silent=True)
      if not isinstance(req_data, dict):
          req_data = {}
      task_id = req_data.get('taskId', None)
      if not task_id:
          return wrap_response(400, msg='task id is required')
      task = app.task_manager.get_task(task_id)
      # Presumably get_task returns None for an unknown id; without this guard
      # ``task.status`` raised AttributeError (HTTP 500).
      if not task:
          return wrap_response(404, msg='task not found')
      if task.status != TestTaskStatus.CANCELLED.value:
          return wrap_response(400, msg='task status is invalid')
      app.task_manager.resume_task(task_id)
      return wrap_response(200)
  @app.errorhandler(werkzeug.exceptions.BadRequest)
  def handle_bad_request(e):
      """Log any werkzeug ``BadRequest`` and reply with a 400 ``wrap_response``.

      This also covers ``BadRequestKeyError`` raised by ``request.args[...]``
      lookups for missing query parameters.
      """
      logger.error(e)
      detail = e.description
      return wrap_response(400, msg=f'Bad Request: {detail}')
  if __name__ == '__main__':
      # Entry point: parse CLI flags, load config, attach the service objects the
      # route handlers read from ``app``, then start Flask's built-in server.
      parser = ArgumentParser()
      parser.add_argument('--prod', action='store_true')
      parser.add_argument('--host', default='127.0.0.1')
      parser.add_argument('--port', type=int, default=8083)
      # logging.getLevelName() only maps upper-case names and returns the string
      # "Level <x>" for anything else, so normalise case and validate up front.
      parser.add_argument(
          '--log-level',
          default='INFO',
          type=str.upper,
          choices=['CRITICAL', 'FATAL', 'ERROR', 'WARNING', 'WARN', 'INFO', 'DEBUG', 'NOTSET'],
      )
      args = parser.parse_args()
      config = configs.get()
      logging_level = logging.getLevelName(args.log_level)
      logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
      # set db config
      storage_config = config['storage']
      user_db_config = storage_config['user']
      staff_db_config = storage_config['staff']
      agent_state_db_config = storage_config['agent_state']
      chat_history_db_config = storage_config['chat_history']
      # init user manager
      user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
      app.user_manager = user_manager
      # init session manager
      session_manager = MySQLSessionManager(
          db_config=user_db_config['mysql'],
          staff_table=staff_db_config['table'],
          user_table=user_db_config['table'],
          agent_state_table=agent_state_db_config['table'],
          chat_history_table=chat_history_db_config['table'],
      )
      app.session_manager = session_manager
      # One session factory for the agent_state database, shared by the route
      # handlers and the TaskManager.
      agent_db_engine = create_sql_engine(agent_state_db_config['mysql'])
      app.session_maker = sessionmaker(bind=agent_db_engine)
      task_manager = TaskManager(session_maker=app.session_maker)
      app.task_manager = task_manager
      wecom_db_config = storage_config['user_relation']
      user_relation_manager = MySQLUserRelationManager(
          user_db_config['mysql'], wecom_db_config['mysql'],
          staff_db_config['table'],
          user_db_config['table'],
          wecom_db_config['table']['staff'],
          wecom_db_config['table']['relation'],
          wecom_db_config['table']['user'],
      )
      app.user_relation_manager = user_relation_manager
      app.history_dialogue_service = HistoryDialogueService(
          storage_config['history_dialogue']['api_base_url']
      )
      # Without --prod, Flask runs with debug=True (auto-reloader and the
      # interactive Werkzeug debugger); never combine that with a public --host.
      app.run(debug=not args.prod, host=args.host, port=args.port)