api_server.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import json
  5. import time
  6. import logging
  7. from argparse import ArgumentParser
  8. import werkzeug.exceptions
  9. from flask import Flask, request, jsonify
  10. from sqlalchemy.orm import sessionmaker
  11. from pqai_agent import configs
  12. from pqai_agent import chat_service, prompt_templates
  13. from pqai_agent.logging import logger, setup_root_logger
  14. from pqai_agent.toolkit import global_tool_map
  15. from pqai_agent import logging_service, chat_service, prompt_templates
  16. from pqai_agent.data_models.agent_configuration import AgentConfiguration
  17. from pqai_agent.data_models.service_module import ServiceModule
  18. from pqai_agent.history_dialogue_service import HistoryDialogueService
  19. from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
  20. from pqai_agent.utils.db_utils import create_ai_agent_db_engine
  21. from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
  22. from pqai_agent_server.const import AgentApiConst
  23. from pqai_agent_server.const.status_enum import TestTaskStatus
  24. from pqai_agent_server.const.type_enum import EvaluateType
  25. from pqai_agent_server.dataset_service import DatasetService
  26. from pqai_agent_server.models import MySQLSessionManager
  27. import pqai_agent_server.utils
  28. from pqai_agent_server.utils import wrap_response
  29. from pqai_agent_server.task_server import TaskManager
  30. from pqai_agent_server.utils import (
  31. run_extractor_prompt,
  32. run_chat_prompt,
  33. run_response_type_prompt,
  34. )
  35. from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
  36. app = Flask('agent_api_server')
  37. const = AgentApiConst()
  38. @app.route('/api/listStaffs', methods=['GET'])
  39. def list_staffs():
  40. staff_data = app.user_relation_manager.list_staffs()
  41. return wrap_response(200, data=staff_data)
  42. @app.route('/api/getStaffProfile', methods=['GET'])
  43. def get_staff_profile():
  44. staff_id = request.args['staff_id']
  45. profile = app.user_manager.get_staff_profile(staff_id)
  46. if not profile:
  47. return wrap_response(404, msg='staff not found')
  48. else:
  49. return wrap_response(200, data=profile)
  50. @app.route('/api/getUserProfile', methods=['GET'])
  51. def get_user_profile():
  52. user_id = request.args['user_id']
  53. profile = app.user_manager.get_user_profile(user_id)
  54. if not profile:
  55. resp = {
  56. 'code': 404,
  57. 'msg': 'user not found'
  58. }
  59. else:
  60. resp = {
  61. 'code': 200,
  62. 'msg': 'success',
  63. 'data': profile
  64. }
  65. return jsonify(resp)
  66. @app.route('/api/listUsers', methods=['GET'])
  67. def list_users():
  68. user_name = request.args.get('user_name', None)
  69. user_union_id = request.args.get('user_union_id', None)
  70. if not user_name and not user_union_id:
  71. resp = {
  72. 'code': 400,
  73. 'msg': 'user_name or user_union_id is required'
  74. }
  75. return jsonify(resp)
  76. data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
  77. return jsonify({'code': 200, 'data': data})
  78. @app.route('/api/getDialogueHistory', methods=['GET'])
  79. def get_dialogue_history():
  80. staff_id = request.args['staff_id']
  81. user_id = request.args['user_id']
  82. recent_minutes = int(request.args.get('recent_minutes', 1440))
  83. dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
  84. return jsonify({'code': 200, 'data': dialogue_history})
  85. @app.route('/api/listModels', methods=['GET'])
  86. def list_models():
  87. models = {
  88. "deepseek-chat": chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
  89. "gpt-4o": chat_service.OPENAI_MODEL_GPT_4o,
  90. "gpt-4o-mini": chat_service.OPENAI_MODEL_GPT_4o_mini,
  91. "doubao-pro-32k": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  92. "doubao-pro-1.5": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
  93. "doubao-1.5-vision-pro": chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
  94. }
  95. ret_data = [
  96. {
  97. 'model_type': 'openai_compatible',
  98. 'model_name': model_name,
  99. 'display_name': model_display_name
  100. }
  101. for model_display_name, model_name in models.items()
  102. ]
  103. return wrap_response(200, data=ret_data)
  104. @app.route('/api/listScenes', methods=['GET'])
  105. def list_scenes():
  106. scenes = [
  107. {'scene': 'greeting', 'display_name': '问候'},
  108. {'scene': 'chitchat', 'display_name': '闲聊'},
  109. {'scene': 'profile_extractor', 'display_name': '画像提取'},
  110. {'scene': 'response_type_detector', 'display_name': '回复模态判断'},
  111. {'scene': 'custom_debugging', 'display_name': '自定义调试场景'}
  112. ]
  113. return wrap_response(200, data=scenes)
  114. @app.route('/api/getBasePrompt', methods=['GET'])
  115. def get_base_prompt():
  116. scene = request.args['scene']
  117. prompt_map = {
  118. 'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
  119. 'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
  120. 'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT_V2,
  121. 'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
  122. 'custom_debugging': '',
  123. }
  124. model_map = {
  125. 'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  126. 'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  127. 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
  128. 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
  129. 'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
  130. }
  131. if scene not in prompt_map:
  132. return wrap_response(404, msg='scene not found')
  133. data = {
  134. 'model_name': model_map[scene],
  135. 'content': prompt_map[scene]
  136. }
  137. return wrap_response(200, data=data)
  138. @app.route('/api/runPrompt', methods=['POST'])
  139. def run_prompt():
  140. try:
  141. req_data = request.json
  142. logger.debug(req_data)
  143. scene = req_data['scene']
  144. if scene == 'profile_extractor':
  145. response = run_extractor_prompt(req_data)
  146. return wrap_response(200, data=response)
  147. elif scene == 'response_type_detector':
  148. response = run_response_type_prompt(req_data)
  149. return wrap_response(200, data=response.choices[0].message.content)
  150. else:
  151. response = run_chat_prompt(req_data)
  152. return wrap_response(200, data=response.choices[0].message.content)
  153. except Exception as e:
  154. logger.error(e)
  155. return wrap_response(500, msg='Error: {}'.format(e))
  156. @app.route('/api/formatForPrompt', methods=['POST'])
  157. def format_data_for_prompt():
  158. try:
  159. req_data = request.json
  160. content = req_data['content']
  161. format_type = req_data['format_type']
  162. if format_type == 'staff_profile':
  163. if not isinstance(content, dict):
  164. return wrap_response(400, msg='staff_profile should be a dict')
  165. response = format_agent_profile(content)
  166. elif format_type == 'user_profile':
  167. if not isinstance(content, dict):
  168. return wrap_response(400, msg='user_profile should be a dict')
  169. response = format_user_profile(content)
  170. elif format_type == 'dialogue':
  171. if not isinstance(content, list):
  172. return wrap_response(400, msg='dialogue should be a list')
  173. from pqai_agent_server.utils.prompt_util import format_dialogue_history
  174. response = format_dialogue_history(content)
  175. else:
  176. return wrap_response(400, msg='Invalid format_type')
  177. return wrap_response(200, data=response)
  178. except Exception as e:
  179. logger.error(e)
  180. return wrap_response(500, msg='Error: {}'.format(e))
  181. @app.route("/api/healthCheck", methods=["GET"])
  182. def health_check():
  183. return wrap_response(200, msg="OK")
  184. @app.route("/api/getStaffSessionSummary", methods=["GET"])
  185. def get_staff_session_summary():
  186. staff_id = request.args.get("staff_id")
  187. status = request.args.get("status", const.DEFAULT_STAFF_STATUS)
  188. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  189. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  190. # check params
  191. try:
  192. page_id = int(page_id)
  193. page_size = int(page_size)
  194. status = int(status)
  195. except Exception as e:
  196. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  197. staff_session_summary = app.session_manager.get_staff_sessions_summary(
  198. staff_id, page_id, page_size, status
  199. )
  200. if not staff_session_summary:
  201. return wrap_response(404, msg="staff not found")
  202. else:
  203. return wrap_response(200, data=staff_session_summary)
  204. @app.route("/api/getStaffSessionList", methods=["GET"])
  205. def get_staff_session_list():
  206. staff_id = request.args.get("staff_id")
  207. if not staff_id:
  208. return wrap_response(404, msg="staff_id is required")
  209. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  210. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  211. # check params
  212. try:
  213. page_id = int(page_id)
  214. page_size = int(page_size)
  215. except Exception as e:
  216. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  217. staff_session_list = app.session_manager.get_staff_session_list(staff_id, page_id, page_size)
  218. if not staff_session_list:
  219. return wrap_response(404, msg="staff not found")
  220. return wrap_response(200, data=staff_session_list)
  221. @app.route("/api/getStaffList", methods=["GET"])
  222. def get_staff_list():
  223. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  224. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  225. # check params
  226. try:
  227. page_id = int(page_id)
  228. page_size = int(page_size)
  229. except Exception as e:
  230. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  231. staff_list = app.user_manager.get_staff_list(page_id, page_size)
  232. if not staff_list:
  233. return wrap_response(404, msg="staff not found")
  234. return wrap_response(200, data=staff_list)
  235. @app.route("/api/getConversationList", methods=["GET"])
  236. def get_conversation_list():
  237. """
  238. 获取staff && user 私聊对话列表
  239. :return:
  240. """
  241. staff_id = request.args.get("staff_id")
  242. user_id = request.args.get("user_id")
  243. if not staff_id or not user_id:
  244. return wrap_response(404, msg="staff_id and user_id are required")
  245. page = request.args.get("page")
  246. response = app.session_manager.get_conversation_list(staff_id, user_id, page, const.DEFAULT_CONVERSATION_SIZE)
  247. return wrap_response(200, data=response)
  248. @app.route("/api/sendMessage", methods=["POST"])
  249. def send_message():
  250. return wrap_response(200, msg="暂不实现功能")
  251. @app.route("/api/quitHumanIntervention", methods=["POST"])
  252. def quit_human_intervention():
  253. """
  254. 退出人工介入状态
  255. :return:
  256. """
  257. req_data = request.json
  258. staff_id = req_data["staff_id"]
  259. user_id = req_data["user_id"]
  260. if not user_id or not staff_id:
  261. return wrap_response(404, msg="user_id and staff_id are required")
  262. if pqai_agent_server.utils.common.quit_human_intervention(user_id, staff_id):
  263. return wrap_response(200, msg="success")
  264. else:
  265. return wrap_response(500, msg="error")
  266. @app.route("/api/enterHumanIntervention", methods=["POST"])
  267. def enter_human_intervention():
  268. """
  269. 进入人工介入状态
  270. :return:
  271. """
  272. req_data = request.json
  273. staff_id = req_data["staff_id"]
  274. user_id = req_data["user_id"]
  275. if not user_id or not staff_id:
  276. return wrap_response(404, msg="user_id and staff_id are required")
  277. if pqai_agent_server.utils.common.enter_human_intervention(user_id, staff_id):
  278. return wrap_response(200, msg="success")
  279. else:
  280. return wrap_response(500, msg="error")
  281. ## Agent管理接口
  282. @app.route("/api/getNativeAgentList", methods=["GET"])
  283. def get_native_agent_list():
  284. """
  285. 获取所有的Agent列表
  286. :return:
  287. """
  288. page = request.args.get('page', 1)
  289. page_size = request.args.get('page_size', 50)
  290. create_user = request.args.get('create_user', None)
  291. update_user = request.args.get('update_user', None)
  292. offset = (int(page) - 1) * int(page_size)
  293. with app.session_maker() as session:
  294. query = session.query(AgentConfiguration) \
  295. .filter(AgentConfiguration.is_delete == 0)
  296. if create_user:
  297. query = query.filter(AgentConfiguration.create_user == create_user)
  298. if update_user:
  299. query = query.filter(AgentConfiguration.update_user == update_user)
  300. total = query.count()
  301. query = query.offset(offset).limit(int(page_size))
  302. data = query.all()
  303. ret_data = {
  304. 'total': total,
  305. 'agent_list': [
  306. {
  307. 'id': agent.id,
  308. 'name': agent.name,
  309. 'display_name': agent.display_name,
  310. 'type': agent.type,
  311. 'execution_model': agent.execution_model,
  312. 'create_user': agent.create_user,
  313. 'update_user': agent.update_user,
  314. 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  315. 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
  316. }
  317. for agent in data
  318. ]
  319. }
  320. return wrap_response(200, data=ret_data)
  321. @app.route("/api/getNativeAgentConfiguration", methods=["GET"])
  322. def get_native_agent_configuration():
  323. """
  324. 获取指定Agent的配置
  325. :return:
  326. """
  327. agent_id = request.args.get('agent_id')
  328. if not agent_id:
  329. return wrap_response(404, msg='agent_id is required')
  330. with app.session_maker() as session:
  331. agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
  332. if not agent:
  333. return wrap_response(404, msg='Agent not found')
  334. data = {
  335. 'id': agent.id,
  336. 'name': agent.name,
  337. 'display_name': agent.display_name,
  338. 'type': agent.type,
  339. 'execution_model': agent.execution_model,
  340. 'system_prompt': agent.system_prompt,
  341. 'task_prompt': agent.task_prompt,
  342. 'tools': json.loads(agent.tools),
  343. 'sub_agents': json.loads(agent.sub_agents),
  344. 'extra_params': json.loads(agent.extra_params),
  345. 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  346. 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
  347. }
  348. return wrap_response(200, data=data)
  349. @app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
  350. def save_native_agent_configuration():
  351. """
  352. 保存Agent配置
  353. :return:
  354. """
  355. req_data = request.json
  356. agent_id = req_data.get('agent_id', None)
  357. name = req_data.get('name')
  358. display_name = req_data.get('display_name', None)
  359. type_ = req_data.get('type', 0)
  360. execution_model = req_data.get('execution_model', None)
  361. system_prompt = req_data.get('system_prompt', None)
  362. task_prompt = req_data.get('task_prompt', None)
  363. tools = json.dumps(req_data.get('tools', []))
  364. sub_agents = json.dumps(req_data.get('sub_agents', []))
  365. extra_params = req_data.get('extra_params', {})
  366. operate_user = req_data.get('user', None)
  367. if isinstance(extra_params, dict):
  368. extra_params = json.dumps(extra_params)
  369. elif isinstance(extra_params, str):
  370. try:
  371. json.loads(extra_params)
  372. except json.JSONDecodeError:
  373. return wrap_response(400, msg='extra_params should be a valid JSON object or string')
  374. if not extra_params:
  375. extra_params = '{}'
  376. if not name:
  377. return wrap_response(400, msg='name is required')
  378. with app.session_maker() as session:
  379. if agent_id:
  380. agent_id = int(agent_id)
  381. agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
  382. if not agent:
  383. return wrap_response(404, msg='Agent not found')
  384. agent.name = name
  385. agent.display_name = display_name
  386. agent.type = type_
  387. agent.execution_model = execution_model
  388. agent.system_prompt = system_prompt
  389. agent.task_prompt = task_prompt
  390. agent.tools = tools
  391. agent.sub_agents = sub_agents
  392. agent.extra_params = extra_params
  393. agent.update_user = operate_user
  394. else:
  395. agent = AgentConfiguration(
  396. name=name,
  397. display_name=display_name,
  398. type=type_,
  399. execution_model=execution_model,
  400. system_prompt=system_prompt,
  401. task_prompt=task_prompt,
  402. tools=tools,
  403. sub_agents=sub_agents,
  404. extra_params=extra_params,
  405. create_user=operate_user,
  406. update_user=operate_user
  407. )
  408. session.add(agent)
  409. session.commit()
  410. return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
  411. @app.route("/api/deleteNativeAgentConfiguration", methods=["POST"])
  412. def delete_native_agent_configuration():
  413. """
  414. 删除指定Agent配置(软删除,设置is_delete=1)
  415. :return:
  416. """
  417. req_data = request.json
  418. agent_id = req_data.get('agent_id', None)
  419. if not agent_id:
  420. return wrap_response(400, msg='agent_id is required')
  421. try:
  422. agent_id = int(agent_id)
  423. except ValueError:
  424. return wrap_response(400, msg='agent_id must be an integer')
  425. with app.session_maker() as session:
  426. agent = session.query(AgentConfiguration).filter(
  427. AgentConfiguration.id == agent_id,
  428. AgentConfiguration.is_delete == 0
  429. ).first()
  430. if not agent:
  431. return wrap_response(404, msg='Agent not found')
  432. agent.is_delete = 1
  433. session.commit()
  434. return wrap_response(200, msg='Agent configuration deleted successfully')
  435. @app.route("/api/getModuleList", methods=["GET"])
  436. def get_module_list():
  437. """
  438. 获取所有的模块列表
  439. :return:
  440. """
  441. with app.session_maker() as session:
  442. query = session.query(ServiceModule) \
  443. .filter(ServiceModule.is_delete == 0)
  444. data = query.all()
  445. ret_data = [
  446. {
  447. 'id': module.id,
  448. 'name': module.name,
  449. 'display_name': module.display_name,
  450. 'default_agent_type': module.default_agent_type,
  451. 'default_agent_id': module.default_agent_id,
  452. 'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  453. 'update_time': module.update_time.strftime('%Y-%m-%d %H:%M:%S')
  454. }
  455. for module in data
  456. ]
  457. return wrap_response(200, data=ret_data)
  458. @app.route("/api/getModuleConfiguration", methods=["GET"])
  459. def get_module_configuration():
  460. """
  461. 获取指定模块的配置
  462. :return:
  463. """
  464. module_id = request.args.get('module_id')
  465. if not module_id:
  466. return wrap_response(404, msg='module_id is required')
  467. with app.session_maker() as session:
  468. module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
  469. if not module:
  470. return wrap_response(404, msg='Module not found')
  471. data = {
  472. 'id': module.id,
  473. 'name': module.name,
  474. 'display_name': module.display_name,
  475. 'default_agent_type': module.default_agent_type,
  476. 'default_agent_id': module.default_agent_id,
  477. 'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  478. 'updated_time': module.updated_time.strftime('%Y-%m-%d %H:%M:%S')
  479. }
  480. return wrap_response(200, data=data)
  481. @app.route("/api/saveModuleConfiguration", methods=["POST"])
  482. def save_module_configuration():
  483. """
  484. 保存模块配置
  485. :return:
  486. """
  487. req_data = request.json
  488. module_id = req_data.get('module_id', None)
  489. name = req_data.get('name')
  490. display_name = req_data.get('display_name', None)
  491. default_agent_type = req_data.get('default_agent_type', 0)
  492. default_agent_id = req_data.get('default_agent_id', None)
  493. if not name:
  494. return wrap_response(400, msg='name is required')
  495. with app.session_maker() as session:
  496. if module_id:
  497. module_id = int(module_id)
  498. module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
  499. if not module:
  500. return wrap_response(404, msg='Module not found')
  501. module.name = name
  502. module.display_name = display_name
  503. module.default_agent_type = default_agent_type
  504. module.default_agent_id = default_agent_id
  505. else:
  506. module = ServiceModule(
  507. name=name,
  508. display_name=display_name,
  509. default_agent_type=default_agent_type,
  510. default_agent_id=default_agent_id
  511. )
  512. session.add(module)
  513. session.commit()
  514. return wrap_response(200, msg='Module configuration saved successfully', data={'id': module.id})
  515. @app.route("/api/getTestTaskList", methods=["GET"])
  516. def get_test_task_list():
  517. """
  518. 获取单元测试任务列表
  519. :return:
  520. """
  521. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  522. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  523. try:
  524. page_num = int(page_num)
  525. page_size = int(page_size)
  526. except Exception as e:
  527. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  528. response = app.task_manager.get_test_task_list(page_num, page_size)
  529. return wrap_response(200, data=response)
  530. @app.route("/api/getTestTaskConversations", methods=["GET"])
  531. def get_test_task_conversations():
  532. """
  533. 获取单元测试对话任务列表
  534. :return:
  535. """
  536. task_id = request.args.get("taskId", None)
  537. if not task_id:
  538. return wrap_response(404, msg='task_id is required')
  539. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  540. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  541. try:
  542. page_num = int(page_num)
  543. page_size = int(page_size)
  544. except Exception as e:
  545. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  546. response = app.task_manager.get_test_task_conversations(int(task_id), page_num, page_size)
  547. return wrap_response(200, data=response)
  548. @app.route("/api/createTestTask", methods=["POST"])
  549. def create_test_task():
  550. """
  551. 创建单元测试任务
  552. :return:
  553. """
  554. req_data = request.json
  555. agent_id = req_data.get('agentId', None)
  556. module_id = req_data.get('moduleId', None)
  557. evaluate_type = req_data.get('evaluateType', None)
  558. if not agent_id:
  559. return wrap_response(404, msg='agent id is required')
  560. if not module_id:
  561. return wrap_response(404, msg='module id is required')
  562. if not evaluate_type:
  563. return wrap_response(404, msg='evaluate_type id is required')
  564. app.task_manager.create_task(agent_id, module_id, evaluate_type)
  565. return wrap_response(200)
  566. @app.route("/api/stopTestTask", methods=["POST"])
  567. def stop_test_task():
  568. """
  569. 停止单元测试任务
  570. :return:
  571. """
  572. req_data = request.json
  573. task_id = req_data.get('taskId', None)
  574. if not task_id:
  575. return wrap_response(400, msg='task id is required')
  576. task = app.task_manager.get_task(task_id)
  577. if task.status not in (TestTaskStatus.NOT_STARTED.value, TestTaskStatus.IN_PROGRESS.value):
  578. return wrap_response(400, msg='task status is invalid')
  579. app.task_manager.cancel_task(task_id)
  580. return wrap_response(200)
  581. @app.route("/api/resumeTestTask", methods=["POST"])
  582. def resume_test_task():
  583. """
  584. 恢复停止的单元测试任务
  585. :return:
  586. """
  587. req_data = request.json
  588. task_id = req_data.get('taskId', None)
  589. if not task_id:
  590. return wrap_response(400, msg='task id is required')
  591. task = app.task_manager.get_task(task_id)
  592. if task.status != TestTaskStatus.CANCELLED.value:
  593. return wrap_response(400, msg='task status is invalid')
  594. app.task_manager.resume_task(task_id)
  595. return wrap_response(200)
  596. @app.route("/api/getEvaluateType", methods=["GET"])
  597. def get_evaluate_type():
  598. """
  599. 获取评估类型
  600. :return:
  601. """
  602. name_desc_list = [
  603. {
  604. "type": item.value,
  605. "desc": item.description
  606. }
  607. for item in EvaluateType]
  608. return wrap_response(code=200, data=name_desc_list)
  609. @app.route("/api/getDatasetList", methods=["GET"])
  610. def get_dataset_list():
  611. """
  612. 获取数据集列表
  613. :return:
  614. """
  615. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  616. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  617. try:
  618. page_num = int(page_num)
  619. page_size = int(page_size)
  620. except Exception as e:
  621. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  622. response = app.dataset_service.get_dataset_list(page_num, page_size)
  623. return wrap_response(200, data=response)
  624. @app.route("/api/getConversationDataList", methods=["GET"])
  625. def get_conversation_data_list():
  626. """
  627. 获取对话列表
  628. :return:
  629. """
  630. dataset_id = request.args.get("datasetId", None)
  631. if not dataset_id:
  632. return wrap_response(404, msg='dataset_id is required')
  633. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  634. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  635. try:
  636. page_num = int(page_num)
  637. page_size = int(page_size)
  638. except Exception as e:
  639. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  640. response = app.dataset_service.get_conversation_data_list(int(dataset_id), page_num, page_size)
  641. return wrap_response(200, data=response)
  642. @app.route("/api/getToolList", methods=["GET"])
  643. def get_tool_list():
  644. """
  645. 获取所有的工具列表
  646. :return:
  647. """
  648. tools = []
  649. for tool_name, tool in global_tool_map.items():
  650. tools.append({
  651. 'name': tool_name,
  652. 'description': tool.get_function_description(),
  653. 'parameters': tool.parameters if hasattr(tool, 'parameters') else {}
  654. })
  655. return wrap_response(200, data=tools)
  656. @app.route("/api/getModuleAgentTypes", methods=["GET"])
  657. def get_agent_types():
  658. """
  659. 获取所有的Agent类型
  660. :return:
  661. """
  662. agent_types = [
  663. {'type': 0, 'display_name': '原生'},
  664. {'type': 1, 'display_name': 'Coze'}
  665. ]
  666. return wrap_response(200, data=agent_types)
  667. @app.errorhandler(werkzeug.exceptions.BadRequest)
  668. def handle_bad_request(e):
  669. logger.error(e)
  670. return wrap_response(400, msg='Bad Request: {}'.format(e.description))
  671. if __name__ == '__main__':
  672. parser = ArgumentParser()
  673. parser.add_argument('--prod', action='store_true')
  674. parser.add_argument('--host', default='127.0.0.1')
  675. parser.add_argument('--port', type=int, default=8083)
  676. parser.add_argument('--log-level', default='INFO')
  677. args = parser.parse_args()
  678. config = configs.get()
  679. logging_level = logging.getLevelName(args.log_level)
  680. setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
  681. # set db config
  682. agent_db_config = config['database']['ai_agent']
  683. growth_db_config = config['database']['growth']
  684. user_db_config = config['storage']['user']
  685. staff_db_config = config['storage']['staff']
  686. agent_state_db_config = config['storage']['agent_state']
  687. chat_history_db_config = config['storage']['chat_history']
  688. # init user manager
  689. user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
  690. app.user_manager = user_manager
  691. # init session manager
  692. session_manager = MySQLSessionManager(
  693. db_config=agent_db_config,
  694. staff_table=staff_db_config['table'],
  695. user_table=user_db_config['table'],
  696. agent_state_table=agent_state_db_config['table'],
  697. chat_history_table=chat_history_db_config['table']
  698. )
  699. app.session_manager = session_manager
  700. agent_db_engine = create_ai_agent_db_engine()
  701. app.session_maker = sessionmaker(bind=agent_db_engine)
  702. dataset_service = DatasetService(session_maker=sessionmaker(bind=agent_db_engine))
  703. app.dataset_service = dataset_service
  704. task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine), dataset_service=dataset_service)
  705. app.task_manager = task_manager
  706. task_manager.recover_tasks()
  707. wecom_db_config = config['storage']['user_relation']
  708. user_relation_manager = MySQLUserRelationManager(
  709. agent_db_config, growth_db_config,
  710. config['storage']['staff']['table'],
  711. user_db_config['table'],
  712. wecom_db_config['table']['staff'],
  713. wecom_db_config['table']['relation'],
  714. wecom_db_config['table']['user']
  715. )
  716. app.user_relation_manager = user_relation_manager
  717. app.history_dialogue_service = HistoryDialogueService(
  718. config['storage']['history_dialogue']['api_base_url']
  719. )
  720. app.run(debug=not args.prod, host=args.host, port=args.port)