api_server.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import json
  5. import logging
  6. from argparse import ArgumentParser
  7. import werkzeug.exceptions
  8. from flask import Flask, request, jsonify
  9. from sqlalchemy.orm import sessionmaker
  10. import pqai_agent_server.utils
  11. from pqai_agent import chat_service, prompt_templates
  12. from pqai_agent import configs
  13. from pqai_agent.data_models.agent_configuration import AgentConfiguration
  14. from pqai_agent.data_models.service_module import ServiceModule
  15. from pqai_agent.history_dialogue_service import HistoryDialogueService
  16. from pqai_agent.logging import logger, setup_root_logger
  17. from pqai_agent.toolkit import global_tool_map
  18. from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
  19. from pqai_agent.utils.db_utils import create_ai_agent_db_engine
  20. from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
  21. from pqai_agent_server.agent_task_server import AgentTaskManager
  22. from pqai_agent_server.const import AgentApiConst
  23. from pqai_agent_server.const.status_enum import TestTaskStatus
  24. from pqai_agent_server.const.type_enum import EvaluateType
  25. from pqai_agent_server.dataset_service import DatasetService
  26. from pqai_agent_server.models import MySQLSessionManager
  27. from pqai_agent_server.task_server import TaskManager
  28. from pqai_agent_server.utils import (
  29. run_extractor_prompt,
  30. run_chat_prompt,
  31. run_response_type_prompt,
  32. )
  33. from pqai_agent_server.utils import wrap_response
# Flask application object; every route in this file is registered on it, and
# the service objects built in the __main__ section are attached to it as
# attributes (app.user_manager, app.session_manager, ...).
app = Flask('agent_api_server')
# Shared API constants: the routes read the default paging values, the default
# staff status and the default conversation page size from this object.
const = AgentApiConst()
  36. @app.route('/api/listStaffs', methods=['GET'])
  37. def list_staffs():
  38. staff_data = app.user_relation_manager.list_staffs()
  39. return wrap_response(200, data=staff_data)
  40. @app.route('/api/getStaffProfile', methods=['GET'])
  41. def get_staff_profile():
  42. staff_id = request.args['staff_id']
  43. profile = app.user_manager.get_staff_profile(staff_id)
  44. if not profile:
  45. return wrap_response(404, msg='staff not found')
  46. else:
  47. return wrap_response(200, data=profile)
  48. @app.route('/api/getUserProfile', methods=['GET'])
  49. def get_user_profile():
  50. user_id = request.args['user_id']
  51. profile = app.user_manager.get_user_profile(user_id)
  52. if not profile:
  53. resp = {
  54. 'code': 404,
  55. 'msg': 'user not found'
  56. }
  57. else:
  58. resp = {
  59. 'code': 200,
  60. 'msg': 'success',
  61. 'data': profile
  62. }
  63. return jsonify(resp)
  64. @app.route('/api/listUsers', methods=['GET'])
  65. def list_users():
  66. user_name = request.args.get('user_name', None)
  67. user_union_id = request.args.get('user_union_id', None)
  68. if not user_name and not user_union_id:
  69. resp = {
  70. 'code': 400,
  71. 'msg': 'user_name or user_union_id is required'
  72. }
  73. return jsonify(resp)
  74. data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
  75. return jsonify({'code': 200, 'data': data})
  76. @app.route('/api/getDialogueHistory', methods=['GET'])
  77. def get_dialogue_history():
  78. staff_id = request.args['staff_id']
  79. user_id = request.args['user_id']
  80. recent_minutes = int(request.args.get('recent_minutes', 1440))
  81. dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
  82. return jsonify({'code': 200, 'data': dialogue_history})
  83. @app.route('/api/listModels', methods=['GET'])
  84. def list_models():
  85. models = {
  86. "deepseek-chat": chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
  87. "gpt-4o": chat_service.OPENAI_MODEL_GPT_4o,
  88. "gpt-4o-mini": chat_service.OPENAI_MODEL_GPT_4o_mini,
  89. "doubao-pro-32k": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  90. "doubao-pro-1.5": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
  91. "doubao-1.5-vision-pro": chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
  92. }
  93. ret_data = [
  94. {
  95. 'model_type': 'openai_compatible',
  96. 'model_name': model_name,
  97. 'display_name': model_display_name
  98. }
  99. for model_display_name, model_name in models.items()
  100. ]
  101. return wrap_response(200, data=ret_data)
  102. @app.route('/api/listScenes', methods=['GET'])
  103. def list_scenes():
  104. scenes = [
  105. {'scene': 'greeting', 'display_name': '问候'},
  106. {'scene': 'chitchat', 'display_name': '闲聊'},
  107. {'scene': 'profile_extractor', 'display_name': '画像提取'},
  108. {'scene': 'response_type_detector', 'display_name': '回复模态判断'},
  109. {'scene': 'custom_debugging', 'display_name': '自定义调试场景'}
  110. ]
  111. return wrap_response(200, data=scenes)
  112. @app.route('/api/getBasePrompt', methods=['GET'])
  113. def get_base_prompt():
  114. scene = request.args['scene']
  115. prompt_map = {
  116. 'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
  117. 'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
  118. 'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT_V2,
  119. 'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
  120. 'custom_debugging': '',
  121. }
  122. model_map = {
  123. 'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  124. 'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  125. 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
  126. 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
  127. 'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
  128. }
  129. if scene not in prompt_map:
  130. return wrap_response(404, msg='scene not found')
  131. data = {
  132. 'model_name': model_map[scene],
  133. 'content': prompt_map[scene]
  134. }
  135. return wrap_response(200, data=data)
  136. @app.route('/api/runPrompt', methods=['POST'])
  137. def run_prompt():
  138. try:
  139. req_data = request.json
  140. logger.debug(req_data)
  141. scene = req_data['scene']
  142. if scene == 'profile_extractor':
  143. response = run_extractor_prompt(req_data)
  144. return wrap_response(200, data=response)
  145. elif scene == 'response_type_detector':
  146. response = run_response_type_prompt(req_data)
  147. return wrap_response(200, data=response.choices[0].message.content)
  148. else:
  149. response = run_chat_prompt(req_data)
  150. return wrap_response(200, data=response.choices[0].message.content)
  151. except Exception as e:
  152. logger.error(e)
  153. return wrap_response(500, msg='Error: {}'.format(e))
  154. @app.route('/api/formatForPrompt', methods=['POST'])
  155. def format_data_for_prompt():
  156. try:
  157. req_data = request.json
  158. content = req_data['content']
  159. format_type = req_data['format_type']
  160. if format_type == 'staff_profile':
  161. if not isinstance(content, dict):
  162. return wrap_response(400, msg='staff_profile should be a dict')
  163. response = format_agent_profile(content)
  164. elif format_type == 'user_profile':
  165. if not isinstance(content, dict):
  166. return wrap_response(400, msg='user_profile should be a dict')
  167. response = format_user_profile(content)
  168. elif format_type == 'dialogue':
  169. if not isinstance(content, list):
  170. return wrap_response(400, msg='dialogue should be a list')
  171. from pqai_agent_server.utils.prompt_util import format_dialogue_history
  172. response = format_dialogue_history(content)
  173. else:
  174. return wrap_response(400, msg='Invalid format_type')
  175. return wrap_response(200, data=response)
  176. except Exception as e:
  177. logger.error(e)
  178. return wrap_response(500, msg='Error: {}'.format(e))
  179. @app.route("/api/healthCheck", methods=["GET"])
  180. def health_check():
  181. return wrap_response(200, msg="OK")
  182. @app.route("/api/getStaffSessionSummary", methods=["GET"])
  183. def get_staff_session_summary():
  184. staff_id = request.args.get("staff_id")
  185. status = request.args.get("status", const.DEFAULT_STAFF_STATUS)
  186. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  187. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  188. # check params
  189. try:
  190. page_id = int(page_id)
  191. page_size = int(page_size)
  192. status = int(status)
  193. except Exception as e:
  194. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  195. staff_session_summary = app.session_manager.get_staff_sessions_summary(
  196. staff_id, page_id, page_size, status
  197. )
  198. if not staff_session_summary:
  199. return wrap_response(404, msg="staff not found")
  200. else:
  201. return wrap_response(200, data=staff_session_summary)
  202. @app.route("/api/getStaffSessionList", methods=["GET"])
  203. def get_staff_session_list():
  204. staff_id = request.args.get("staff_id")
  205. if not staff_id:
  206. return wrap_response(404, msg="staff_id is required")
  207. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  208. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  209. # check params
  210. try:
  211. page_id = int(page_id)
  212. page_size = int(page_size)
  213. except Exception as e:
  214. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  215. staff_session_list = app.session_manager.get_staff_session_list(staff_id, page_id, page_size)
  216. if not staff_session_list:
  217. return wrap_response(404, msg="staff not found")
  218. return wrap_response(200, data=staff_session_list)
  219. @app.route("/api/getStaffList", methods=["GET"])
  220. def get_staff_list():
  221. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  222. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  223. # check params
  224. try:
  225. page_id = int(page_id)
  226. page_size = int(page_size)
  227. except Exception as e:
  228. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  229. staff_list = app.user_manager.get_staff_list(page_id, page_size)
  230. if not staff_list:
  231. return wrap_response(404, msg="staff not found")
  232. return wrap_response(200, data=staff_list)
  233. @app.route("/api/getConversationList", methods=["GET"])
  234. def get_conversation_list():
  235. """
  236. 获取staff && user 私聊对话列表
  237. :return:
  238. """
  239. staff_id = request.args.get("staff_id")
  240. user_id = request.args.get("user_id")
  241. if not staff_id or not user_id:
  242. return wrap_response(404, msg="staff_id and user_id are required")
  243. page = request.args.get("page")
  244. response = app.session_manager.get_conversation_list(staff_id, user_id, page, const.DEFAULT_CONVERSATION_SIZE)
  245. return wrap_response(200, data=response)
  246. @app.route("/api/sendMessage", methods=["POST"])
  247. def send_message():
  248. return wrap_response(200, msg="暂不实现功能")
  249. @app.route("/api/quitHumanIntervention", methods=["POST"])
  250. def quit_human_intervention():
  251. """
  252. 退出人工介入状态
  253. :return:
  254. """
  255. req_data = request.json
  256. staff_id = req_data["staff_id"]
  257. user_id = req_data["user_id"]
  258. if not user_id or not staff_id:
  259. return wrap_response(404, msg="user_id and staff_id are required")
  260. if pqai_agent_server.utils.common.quit_human_intervention(user_id, staff_id):
  261. return wrap_response(200, msg="success")
  262. else:
  263. return wrap_response(500, msg="error")
  264. @app.route("/api/enterHumanIntervention", methods=["POST"])
  265. def enter_human_intervention():
  266. """
  267. 进入人工介入状态
  268. :return:
  269. """
  270. req_data = request.json
  271. staff_id = req_data["staff_id"]
  272. user_id = req_data["user_id"]
  273. if not user_id or not staff_id:
  274. return wrap_response(404, msg="user_id and staff_id are required")
  275. if pqai_agent_server.utils.common.enter_human_intervention(user_id, staff_id):
  276. return wrap_response(200, msg="success")
  277. else:
  278. return wrap_response(500, msg="error")
## Agent management endpoints
  280. @app.route("/api/getNativeAgentList", methods=["GET"])
  281. def get_native_agent_list():
  282. """
  283. 获取所有的Agent列表
  284. :return:
  285. """
  286. page = request.args.get('page', 1)
  287. page_size = request.args.get('page_size', 50)
  288. create_user = request.args.get('create_user', None)
  289. update_user = request.args.get('update_user', None)
  290. offset = (int(page) - 1) * int(page_size)
  291. with app.session_maker() as session:
  292. query = session.query(AgentConfiguration) \
  293. .filter(AgentConfiguration.is_delete == 0)
  294. if create_user:
  295. query = query.filter(AgentConfiguration.create_user == create_user)
  296. if update_user:
  297. query = query.filter(AgentConfiguration.update_user == update_user)
  298. total = query.count()
  299. query = query.offset(offset).limit(int(page_size))
  300. data = query.all()
  301. ret_data = {
  302. 'total': total,
  303. 'agent_list': [
  304. {
  305. 'id': agent.id,
  306. 'name': agent.name,
  307. 'display_name': agent.display_name,
  308. 'type': agent.type,
  309. 'execution_model': agent.execution_model,
  310. 'create_user': agent.create_user,
  311. 'update_user': agent.update_user,
  312. 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  313. 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
  314. }
  315. for agent in data
  316. ]
  317. }
  318. return wrap_response(200, data=ret_data)
  319. @app.route("/api/getNativeAgentConfiguration", methods=["GET"])
  320. def get_native_agent_configuration():
  321. """
  322. 获取指定Agent的配置
  323. :return:
  324. """
  325. agent_id = request.args.get('agent_id')
  326. if not agent_id:
  327. return wrap_response(404, msg='agent_id is required')
  328. with app.session_maker() as session:
  329. agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
  330. if not agent:
  331. return wrap_response(404, msg='Agent not found')
  332. data = {
  333. 'id': agent.id,
  334. 'name': agent.name,
  335. 'display_name': agent.display_name,
  336. 'type': agent.type,
  337. 'execution_model': agent.execution_model,
  338. 'system_prompt': agent.system_prompt,
  339. 'task_prompt': agent.task_prompt,
  340. 'tools': json.loads(agent.tools),
  341. 'sub_agents': json.loads(agent.sub_agents),
  342. 'extra_params': json.loads(agent.extra_params),
  343. 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  344. 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
  345. }
  346. return wrap_response(200, data=data)
  347. @app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
  348. def save_native_agent_configuration():
  349. """
  350. 保存Agent配置
  351. :return:
  352. """
  353. req_data = request.json
  354. agent_id = req_data.get('agent_id', None)
  355. name = req_data.get('name')
  356. display_name = req_data.get('display_name', None)
  357. type_ = req_data.get('type', 0)
  358. execution_model = req_data.get('execution_model', None)
  359. system_prompt = req_data.get('system_prompt', None)
  360. task_prompt = req_data.get('task_prompt', None)
  361. tools = json.dumps(req_data.get('tools', []))
  362. sub_agents = json.dumps(req_data.get('sub_agents', []))
  363. extra_params = req_data.get('extra_params', {})
  364. operate_user = req_data.get('user', None)
  365. if isinstance(extra_params, dict):
  366. extra_params = json.dumps(extra_params)
  367. elif isinstance(extra_params, str):
  368. try:
  369. json.loads(extra_params)
  370. except json.JSONDecodeError:
  371. return wrap_response(400, msg='extra_params should be a valid JSON object or string')
  372. if not extra_params:
  373. extra_params = '{}'
  374. if not name:
  375. return wrap_response(400, msg='name is required')
  376. with app.session_maker() as session:
  377. if agent_id:
  378. agent_id = int(agent_id)
  379. agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
  380. if not agent:
  381. return wrap_response(404, msg='Agent not found')
  382. agent.name = name
  383. agent.display_name = display_name
  384. agent.type = type_
  385. agent.execution_model = execution_model
  386. agent.system_prompt = system_prompt
  387. agent.task_prompt = task_prompt
  388. agent.tools = tools
  389. agent.sub_agents = sub_agents
  390. agent.extra_params = extra_params
  391. agent.update_user = operate_user
  392. else:
  393. agent = AgentConfiguration(
  394. name=name,
  395. display_name=display_name,
  396. type=type_,
  397. execution_model=execution_model,
  398. system_prompt=system_prompt,
  399. task_prompt=task_prompt,
  400. tools=tools,
  401. sub_agents=sub_agents,
  402. extra_params=extra_params,
  403. create_user=operate_user,
  404. update_user=operate_user
  405. )
  406. session.add(agent)
  407. session.commit()
  408. return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
  409. @app.route("/api/deleteNativeAgentConfiguration", methods=["POST"])
  410. def delete_native_agent_configuration():
  411. """
  412. 删除指定Agent配置(软删除,设置is_delete=1)
  413. :return:
  414. """
  415. req_data = request.json
  416. agent_id = req_data.get('agent_id', None)
  417. if not agent_id:
  418. return wrap_response(400, msg='agent_id is required')
  419. try:
  420. agent_id = int(agent_id)
  421. except ValueError:
  422. return wrap_response(400, msg='agent_id must be an integer')
  423. with app.session_maker() as session:
  424. agent = session.query(AgentConfiguration).filter(
  425. AgentConfiguration.id == agent_id,
  426. AgentConfiguration.is_delete == 0
  427. ).first()
  428. if not agent:
  429. return wrap_response(404, msg='Agent not found')
  430. agent.is_delete = 1
  431. session.commit()
  432. return wrap_response(200, msg='Agent configuration deleted successfully')
  433. @app.route("/api/getModuleList", methods=["GET"])
  434. def get_module_list():
  435. """
  436. 获取所有的模块列表,支持分页查询
  437. :return:
  438. """
  439. page = request.args.get('page', 1)
  440. page_size = request.args.get('page_size', 50)
  441. try:
  442. page = int(page)
  443. page_size = int(page_size)
  444. except Exception as e:
  445. return wrap_response(400, msg="Invalid parameter: {}".format(e))
  446. offset = (page - 1) * page_size
  447. with app.session_maker() as session:
  448. query = session.query(ServiceModule).filter(ServiceModule.is_delete == 0)
  449. total = query.count()
  450. modules = query.offset(offset).limit(page_size).all()
  451. ret_data = {
  452. 'total': total,
  453. 'module_list': [
  454. {
  455. 'id': module.id,
  456. 'name': module.name,
  457. 'display_name': module.display_name,
  458. 'default_agent_type': module.default_agent_type,
  459. 'default_agent_id': module.default_agent_id,
  460. 'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  461. 'update_time': module.update_time.strftime('%Y-%m-%d %H:%M:%S')
  462. }
  463. for module in modules
  464. ]
  465. }
  466. return wrap_response(200, data=ret_data)
  467. @app.route("/api/getModuleConfiguration", methods=["GET"])
  468. def get_module_configuration():
  469. """
  470. 获取指定模块的配置
  471. :return:
  472. """
  473. module_id = request.args.get('module_id')
  474. if not module_id:
  475. return wrap_response(404, msg='module_id is required')
  476. with app.session_maker() as session:
  477. module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
  478. if not module:
  479. return wrap_response(404, msg='Module not found')
  480. data = {
  481. 'id': module.id,
  482. 'name': module.name,
  483. 'display_name': module.display_name,
  484. 'default_agent_type': module.default_agent_type,
  485. 'default_agent_id': module.default_agent_id,
  486. 'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  487. 'updated_time': module.updated_time.strftime('%Y-%m-%d %H:%M:%S')
  488. }
  489. return wrap_response(200, data=data)
  490. @app.route("/api/saveModuleConfiguration", methods=["POST"])
  491. def save_module_configuration():
  492. """
  493. 保存模块配置
  494. :return:
  495. """
  496. req_data = request.json
  497. module_id = req_data.get('module_id', None)
  498. name = req_data.get('name')
  499. display_name = req_data.get('display_name', None)
  500. default_agent_type = req_data.get('default_agent_type', 0)
  501. default_agent_id = req_data.get('default_agent_id', None)
  502. if not name:
  503. return wrap_response(400, msg='name is required')
  504. with app.session_maker() as session:
  505. if module_id:
  506. module_id = int(module_id)
  507. module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
  508. if not module:
  509. return wrap_response(404, msg='Module not found')
  510. module.name = name
  511. module.display_name = display_name
  512. module.default_agent_type = default_agent_type
  513. module.default_agent_id = default_agent_id
  514. else:
  515. module = ServiceModule(
  516. name=name,
  517. display_name=display_name,
  518. default_agent_type=default_agent_type,
  519. default_agent_id=default_agent_id
  520. )
  521. session.add(module)
  522. session.commit()
  523. return wrap_response(200, msg='Module configuration saved successfully', data={'id': module.id})
  524. @app.route("/api/getTestTaskList", methods=["GET"])
  525. def get_test_task_list():
  526. """
  527. 获取单元测试任务列表
  528. :return:
  529. """
  530. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  531. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  532. try:
  533. page_num = int(page_num)
  534. page_size = int(page_size)
  535. except Exception as e:
  536. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  537. response = app.task_manager.get_test_task_list(page_num, page_size)
  538. return wrap_response(200, data=response)
  539. @app.route("/api/getTestTaskConversations", methods=["GET"])
  540. def get_test_task_conversations():
  541. """
  542. 获取单元测试对话任务列表
  543. :return:
  544. """
  545. task_id = request.args.get("taskId", None)
  546. if not task_id:
  547. return wrap_response(404, msg='task_id is required')
  548. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  549. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  550. try:
  551. page_num = int(page_num)
  552. page_size = int(page_size)
  553. except Exception as e:
  554. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  555. response = app.task_manager.get_test_task_conversations(int(task_id), page_num, page_size)
  556. return wrap_response(200, data=response)
  557. @app.route("/api/createTestTask", methods=["POST"])
  558. def create_test_task():
  559. """
  560. 创建单元测试任务
  561. :return:
  562. """
  563. req_data = request.json
  564. agent_id = req_data.get('agentId', None)
  565. module_id = req_data.get('moduleId', None)
  566. evaluate_type = req_data.get('evaluateType', None)
  567. if not agent_id:
  568. return wrap_response(404, msg='agent id is required')
  569. if not module_id:
  570. return wrap_response(404, msg='module id is required')
  571. if not evaluate_type:
  572. return wrap_response(404, msg='evaluate_type id is required')
  573. app.task_manager.create_task(agent_id, module_id, evaluate_type)
  574. return wrap_response(200)
  575. @app.route("/api/stopTestTask", methods=["POST"])
  576. def stop_test_task():
  577. """
  578. 停止单元测试任务
  579. :return:
  580. """
  581. req_data = request.json
  582. task_id = req_data.get('taskId', None)
  583. if not task_id:
  584. return wrap_response(400, msg='task id is required')
  585. task = app.task_manager.get_task(task_id)
  586. if task.status not in (TestTaskStatus.NOT_STARTED.value, TestTaskStatus.IN_PROGRESS.value):
  587. return wrap_response(400, msg='task status is invalid')
  588. app.task_manager.cancel_task(task_id)
  589. return wrap_response(200)
  590. @app.route("/api/resumeTestTask", methods=["POST"])
  591. def resume_test_task():
  592. """
  593. 恢复停止的单元测试任务
  594. :return:
  595. """
  596. req_data = request.json
  597. task_id = req_data.get('taskId', None)
  598. if not task_id:
  599. return wrap_response(400, msg='task id is required')
  600. task = app.task_manager.get_task(task_id)
  601. if task.status != TestTaskStatus.CANCELLED.value:
  602. return wrap_response(400, msg='task status is invalid')
  603. app.task_manager.resume_task(task_id)
  604. return wrap_response(200)
  605. @app.route("/api/getEvaluateType", methods=["GET"])
  606. def get_evaluate_type():
  607. """
  608. 获取评估类型
  609. :return:
  610. """
  611. name_desc_list = [
  612. {
  613. "type": item.value,
  614. "desc": item.description
  615. }
  616. for item in EvaluateType]
  617. return wrap_response(code=200, data=name_desc_list)
  618. @app.route("/api/getDatasetList", methods=["GET"])
  619. def get_dataset_list():
  620. """
  621. 获取数据集列表
  622. :return:
  623. """
  624. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  625. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  626. try:
  627. page_num = int(page_num)
  628. page_size = int(page_size)
  629. except Exception as e:
  630. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  631. response = app.dataset_service.get_dataset_list(page_num, page_size)
  632. return wrap_response(200, data=response)
  633. @app.route("/api/getConversationDataList", methods=["GET"])
  634. def get_conversation_data_list():
  635. """
  636. 获取对话列表
  637. :return:
  638. """
  639. dataset_id = request.args.get("datasetId", None)
  640. if not dataset_id:
  641. return wrap_response(404, msg='dataset_id is required')
  642. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  643. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  644. try:
  645. page_num = int(page_num)
  646. page_size = int(page_size)
  647. except Exception as e:
  648. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  649. response = app.dataset_service.get_conversation_data_list(int(dataset_id), page_num, page_size)
  650. return wrap_response(200, data=response)
  651. @app.route("/api/getToolList", methods=["GET"])
  652. def get_tool_list():
  653. """
  654. 获取所有的工具列表
  655. :return:
  656. """
  657. tools = []
  658. for tool_name, tool in global_tool_map.items():
  659. tools.append({
  660. 'name': tool_name,
  661. 'description': tool.get_function_description(),
  662. 'parameters': tool.parameters if hasattr(tool, 'parameters') else {}
  663. })
  664. return wrap_response(200, data=tools)
  665. @app.route("/api/getModuleAgentTypes", methods=["GET"])
  666. def get_agent_types():
  667. """
  668. 获取所有的Agent类型
  669. :return:
  670. """
  671. agent_types = [
  672. {'type': 0, 'display_name': '原生'},
  673. {'type': 1, 'display_name': 'Coze'}
  674. ]
  675. return wrap_response(200, data=agent_types)
  676. @app.route("/api/createAgentTask", methods=["POST"])
  677. def create_agent_task():
  678. """
  679. 创建agent执行任务
  680. :return:
  681. """
  682. req_data = request.json
  683. agent_id = req_data.get('agentId', None)
  684. task_prompt = req_data.get('taskPrompt', None)
  685. if not agent_id:
  686. return wrap_response(404, msg='agent id is required')
  687. if not task_prompt:
  688. return wrap_response(404, msg='task_prompt is required')
  689. app.agent_task_manager.create_task(agent_id, task_prompt)
  690. return wrap_response(200)
  691. @app.route("/api/getAgentTaskDetail", methods=["GET"])
  692. def get_agent_task_detail():
  693. """
  694. 查询agent执行任务详情
  695. :return:
  696. """
  697. agent_task_id = request.args.get("agentTaskId", None)
  698. if not agent_task_id:
  699. return wrap_response(404, msg='agent_task_id is required')
  700. response = app.agent_task_manager.get_agent_task_detail(int(agent_task_id))
  701. return wrap_response(200, data=response)
  702. @app.errorhandler(werkzeug.exceptions.BadRequest)
  703. def handle_bad_request(e):
  704. logger.error(e)
  705. return wrap_response(400, msg='Bad Request: {}'.format(e.description))
  706. if __name__ == '__main__':
  707. parser = ArgumentParser()
  708. parser.add_argument('--prod', action='store_true')
  709. parser.add_argument('--host', default='127.0.0.1')
  710. parser.add_argument('--port', type=int, default=8083)
  711. parser.add_argument('--log-level', default='INFO')
  712. args = parser.parse_args()
  713. config = configs.get()
  714. logging_level = logging.getLevelName(args.log_level)
  715. setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
  716. # set db config
  717. agent_db_config = config['database']['ai_agent']
  718. growth_db_config = config['database']['growth']
  719. user_db_config = config['storage']['user']
  720. staff_db_config = config['storage']['staff']
  721. agent_state_db_config = config['storage']['agent_state']
  722. chat_history_db_config = config['storage']['chat_history']
  723. # init user manager
  724. user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
  725. app.user_manager = user_manager
  726. # init session manager
  727. session_manager = MySQLSessionManager(
  728. db_config=agent_db_config,
  729. staff_table=staff_db_config['table'],
  730. user_table=user_db_config['table'],
  731. agent_state_table=agent_state_db_config['table'],
  732. chat_history_table=chat_history_db_config['table']
  733. )
  734. app.session_manager = session_manager
  735. agent_db_engine = create_ai_agent_db_engine()
  736. app.session_maker = sessionmaker(bind=agent_db_engine)
  737. dataset_service = DatasetService(session_maker=sessionmaker(bind=agent_db_engine))
  738. app.dataset_service = dataset_service
  739. task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine), dataset_service=dataset_service)
  740. app.task_manager = task_manager
  741. app.task_manager.recover_tasks()
  742. agent_task_manager = AgentTaskManager(session_maker=sessionmaker(bind=agent_db_engine))
  743. app.agent_task_manager = agent_task_manager
  744. app.agent_task_manager.recover_tasks()
  745. wecom_db_config = config['storage']['user_relation']
  746. user_relation_manager = MySQLUserRelationManager(
  747. agent_db_config, growth_db_config,
  748. config['storage']['staff']['table'],
  749. user_db_config['table'],
  750. wecom_db_config['table']['staff'],
  751. wecom_db_config['table']['relation'],
  752. wecom_db_config['table']['user']
  753. )
  754. app.user_relation_manager = user_relation_manager
  755. app.history_dialogue_service = HistoryDialogueService(
  756. config['storage']['history_dialogue']['api_base_url']
  757. )
  758. app.run(debug=not args.prod, host=args.host, port=args.port)