api_server.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838
  1. #! /usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # vim:fenc=utf-8
  4. import json
  5. import logging
  6. from argparse import ArgumentParser
  7. import werkzeug.exceptions
  8. from flask import Flask, request, jsonify
  9. from sqlalchemy.orm import sessionmaker
  10. import pqai_agent_server
  11. import pqai_agent_server.utils
  12. from pqai_agent import chat_service, prompt_templates
  13. from pqai_agent import configs
  14. from pqai_agent.chat_service import OpenAICompatible
  15. from pqai_agent.data_models.agent_configuration import AgentConfiguration
  16. from pqai_agent.data_models.service_module import ServiceModule
  17. from pqai_agent.history_dialogue_service import HistoryDialogueService
  18. from pqai_agent.logging import logger, setup_root_logger
  19. from pqai_agent.toolkit import global_tool_map
  20. from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
  21. from pqai_agent.utils.db_utils import create_ai_agent_db_engine
  22. from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
  23. from pqai_agent_server.const import AgentApiConst
  24. from pqai_agent_server.const.status_enum import TestTaskStatus
  25. from pqai_agent_server.const.type_enum import EvaluateType
  26. from pqai_agent_server.dataset_service import DatasetService
  27. from pqai_agent_server.models import MySQLSessionManager
  28. from pqai_agent_server.task_server import TaskManager
  29. from pqai_agent_server.utils import (
  30. run_extractor_prompt,
  31. run_chat_prompt,
  32. run_response_type_prompt,
  33. )
  34. from pqai_agent_server.utils import wrap_response
# Flask application object. The service objects used by the handlers
# (user_manager, session_manager, session_maker, dataset_service, task_manager,
# user_relation_manager, history_dialogue_service) are attached to it as
# attributes in the ``__main__`` block at the bottom of this file.
app = Flask('agent_api_server')
# Shared API constants; the handlers read DEFAULT_PAGE_ID, DEFAULT_PAGE_SIZE,
# DEFAULT_STAFF_STATUS and DEFAULT_CONVERSATION_SIZE from it.
const = AgentApiConst()
  37. @app.route('/api/listStaffs', methods=['GET'])
  38. def list_staffs():
  39. staff_data = app.user_relation_manager.list_staffs()
  40. return wrap_response(200, data=staff_data)
  41. @app.route('/api/getStaffProfile', methods=['GET'])
  42. def get_staff_profile():
  43. staff_id = request.args['staff_id']
  44. profile = app.user_manager.get_staff_profile(staff_id)
  45. if not profile:
  46. return wrap_response(404, msg='staff not found')
  47. else:
  48. return wrap_response(200, data=profile)
  49. @app.route('/api/getUserProfile', methods=['GET'])
  50. def get_user_profile():
  51. user_id = request.args['user_id']
  52. profile = app.user_manager.get_user_profile(user_id)
  53. if not profile:
  54. resp = {
  55. 'code': 404,
  56. 'msg': 'user not found'
  57. }
  58. else:
  59. resp = {
  60. 'code': 200,
  61. 'msg': 'success',
  62. 'data': profile
  63. }
  64. return jsonify(resp)
  65. @app.route('/api/listUsers', methods=['GET'])
  66. def list_users():
  67. user_name = request.args.get('user_name', None)
  68. user_union_id = request.args.get('user_union_id', None)
  69. if not user_name and not user_union_id:
  70. resp = {
  71. 'code': 400,
  72. 'msg': 'user_name or user_union_id is required'
  73. }
  74. return jsonify(resp)
  75. data = app.user_manager.list_users(user_name=user_name, user_union_id=user_union_id)
  76. return jsonify({'code': 200, 'data': data})
  77. @app.route('/api/getDialogueHistory', methods=['GET'])
  78. def get_dialogue_history():
  79. staff_id = request.args['staff_id']
  80. user_id = request.args['user_id']
  81. recent_minutes = int(request.args.get('recent_minutes', 1440))
  82. dialogue_history = app.history_dialogue_service.get_dialogue_history(staff_id, user_id, recent_minutes)
  83. return jsonify({'code': 200, 'data': dialogue_history})
  84. @app.route('/api/listModels', methods=['GET'])
  85. def list_models():
  86. models = {
  87. "deepseek-chat": chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
  88. "gpt-4o": chat_service.OPENAI_MODEL_GPT_4o,
  89. "gpt-4o-mini": chat_service.OPENAI_MODEL_GPT_4o_mini,
  90. "doubao-pro-32k": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  91. "doubao-pro-1.5": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
  92. "doubao-1.5-vision-pro": chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
  93. "openrouter-gemini-2.5-pro": chat_service.OPENROUTER_MODEL_GEMINI_2_5_PRO,
  94. }
  95. ret_data = [
  96. {
  97. 'model_type': 'openai_compatible',
  98. 'model_name': model_name,
  99. 'display_name': f"{model_display_name} ({OpenAICompatible.get_price(model_name).get_cny_brief()})"
  100. }
  101. for model_display_name, model_name in models.items()
  102. ]
  103. return wrap_response(200, data=ret_data)
  104. @app.route('/api/listScenes', methods=['GET'])
  105. def list_scenes():
  106. scenes = [
  107. {'scene': 'greeting', 'display_name': '问候'},
  108. {'scene': 'chitchat', 'display_name': '闲聊'},
  109. {'scene': 'profile_extractor', 'display_name': '画像提取'},
  110. {'scene': 'response_type_detector', 'display_name': '回复模态判断'},
  111. {'scene': 'custom_debugging', 'display_name': '自定义调试场景'}
  112. ]
  113. return wrap_response(200, data=scenes)
  114. @app.route('/api/getBasePrompt', methods=['GET'])
  115. def get_base_prompt():
  116. scene = request.args['scene']
  117. prompt_map = {
  118. 'greeting': prompt_templates.GENERAL_GREETING_PROMPT,
  119. 'chitchat': prompt_templates.CHITCHAT_PROMPT_COZE,
  120. 'profile_extractor': prompt_templates.USER_PROFILE_EXTRACT_PROMPT_V2,
  121. 'response_type_detector': prompt_templates.RESPONSE_TYPE_DETECT_PROMPT,
  122. 'custom_debugging': '',
  123. }
  124. model_map = {
  125. 'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  126. 'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
  127. 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
  128. 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
  129. 'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
  130. }
  131. if scene not in prompt_map:
  132. return wrap_response(404, msg='scene not found')
  133. data = {
  134. 'model_name': model_map[scene],
  135. 'content': prompt_map[scene]
  136. }
  137. return wrap_response(200, data=data)
  138. @app.route('/api/runPrompt', methods=['POST'])
  139. def run_prompt():
  140. try:
  141. req_data = request.json
  142. logger.debug(req_data)
  143. scene = req_data['scene']
  144. if scene == 'profile_extractor':
  145. response = run_extractor_prompt(req_data)
  146. return wrap_response(200, data=response)
  147. elif scene == 'response_type_detector':
  148. response = run_response_type_prompt(req_data)
  149. return wrap_response(200, data=response.choices[0].message.content)
  150. else:
  151. response = run_chat_prompt(req_data)
  152. return wrap_response(200, data=response.choices[0].message.content)
  153. except Exception as e:
  154. logger.error(e)
  155. return wrap_response(500, msg='Error: {}'.format(e))
  156. @app.route('/api/formatForPrompt', methods=['POST'])
  157. def format_data_for_prompt():
  158. try:
  159. req_data = request.json
  160. content = req_data['content']
  161. format_type = req_data['format_type']
  162. if format_type == 'staff_profile':
  163. if not isinstance(content, dict):
  164. return wrap_response(400, msg='staff_profile should be a dict')
  165. response = format_agent_profile(content)
  166. elif format_type == 'user_profile':
  167. if not isinstance(content, dict):
  168. return wrap_response(400, msg='user_profile should be a dict')
  169. response = format_user_profile(content)
  170. elif format_type == 'dialogue':
  171. if not isinstance(content, list):
  172. return wrap_response(400, msg='dialogue should be a list')
  173. from pqai_agent_server.utils.prompt_util import format_dialogue_history
  174. response = format_dialogue_history(content)
  175. else:
  176. return wrap_response(400, msg='Invalid format_type')
  177. return wrap_response(200, data=response)
  178. except Exception as e:
  179. logger.error(e)
  180. return wrap_response(500, msg='Error: {}'.format(e))
  181. @app.route("/api/healthCheck", methods=["GET"])
  182. def health_check():
  183. return wrap_response(200, msg="OK")
  184. @app.route("/api/getStaffSessionSummary", methods=["GET"])
  185. def get_staff_session_summary():
  186. staff_id = request.args.get("staff_id")
  187. status = request.args.get("status", const.DEFAULT_STAFF_STATUS)
  188. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  189. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  190. # check params
  191. try:
  192. page_id = int(page_id)
  193. page_size = int(page_size)
  194. status = int(status)
  195. except Exception as e:
  196. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  197. staff_session_summary = app.session_manager.get_staff_sessions_summary(
  198. staff_id, page_id, page_size, status
  199. )
  200. if not staff_session_summary:
  201. return wrap_response(404, msg="staff not found")
  202. else:
  203. return wrap_response(200, data=staff_session_summary)
  204. @app.route("/api/getStaffSessionList", methods=["GET"])
  205. def get_staff_session_list():
  206. staff_id = request.args.get("staff_id")
  207. if not staff_id:
  208. return wrap_response(404, msg="staff_id is required")
  209. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  210. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  211. # check params
  212. try:
  213. page_id = int(page_id)
  214. page_size = int(page_size)
  215. except Exception as e:
  216. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  217. staff_session_list = app.session_manager.get_staff_session_list(staff_id, page_id, page_size)
  218. if not staff_session_list:
  219. return wrap_response(404, msg="staff not found")
  220. return wrap_response(200, data=staff_session_list)
  221. @app.route("/api/getStaffList", methods=["GET"])
  222. def get_staff_list():
  223. page_size = request.args.get("page_size", const.DEFAULT_PAGE_SIZE)
  224. page_id = request.args.get("page_id", const.DEFAULT_PAGE_ID)
  225. # check params
  226. try:
  227. page_id = int(page_id)
  228. page_size = int(page_size)
  229. except Exception as e:
  230. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  231. staff_list = app.user_manager.get_staff_list(page_id, page_size)
  232. if not staff_list:
  233. return wrap_response(404, msg="staff not found")
  234. return wrap_response(200, data=staff_list)
  235. @app.route("/api/getConversationList", methods=["GET"])
  236. def get_conversation_list():
  237. """
  238. 获取staff && user 私聊对话列表
  239. :return:
  240. """
  241. staff_id = request.args.get("staff_id")
  242. user_id = request.args.get("user_id")
  243. if not staff_id or not user_id:
  244. return wrap_response(404, msg="staff_id and user_id are required")
  245. page = request.args.get("page")
  246. response = app.session_manager.get_conversation_list(staff_id, user_id, page, const.DEFAULT_CONVERSATION_SIZE)
  247. return wrap_response(200, data=response)
  248. @app.route("/api/sendMessage", methods=["POST"])
  249. def send_message():
  250. return wrap_response(200, msg="暂不实现功能")
  251. @app.route("/api/quitHumanIntervention", methods=["POST"])
  252. def quit_human_intervention():
  253. """
  254. 退出人工介入状态
  255. :return:
  256. """
  257. req_data = request.json
  258. staff_id = req_data["staff_id"]
  259. user_id = req_data["user_id"]
  260. if not user_id or not staff_id:
  261. return wrap_response(404, msg="user_id and staff_id are required")
  262. if pqai_agent_server.utils.common.quit_human_intervention(user_id, staff_id):
  263. return wrap_response(200, msg="success")
  264. else:
  265. return wrap_response(500, msg="error")
  266. @app.route("/api/enterHumanIntervention", methods=["POST"])
  267. def enter_human_intervention():
  268. """
  269. 进入人工介入状态
  270. :return:
  271. """
  272. req_data = request.json
  273. staff_id = req_data["staff_id"]
  274. user_id = req_data["user_id"]
  275. if not user_id or not staff_id:
  276. return wrap_response(404, msg="user_id and staff_id are required")
  277. if pqai_agent_server.utils.common.enter_human_intervention(user_id, staff_id):
  278. return wrap_response(200, msg="success")
  279. else:
  280. return wrap_response(500, msg="error")
## Agent management endpoints
  282. @app.route("/api/getNativeAgentList", methods=["GET"])
  283. def get_native_agent_list():
  284. """
  285. 获取所有的Agent列表
  286. :return:
  287. """
  288. page = request.args.get('page', 1)
  289. page_size = request.args.get('page_size', 50)
  290. create_user = request.args.get('create_user', None)
  291. update_user = request.args.get('update_user', None)
  292. offset = (int(page) - 1) * int(page_size)
  293. with app.session_maker() as session:
  294. query = session.query(AgentConfiguration) \
  295. .filter(AgentConfiguration.is_delete == 0)
  296. if create_user:
  297. query = query.filter(AgentConfiguration.create_user == create_user)
  298. if update_user:
  299. query = query.filter(AgentConfiguration.update_user == update_user)
  300. total = query.count()
  301. query = query.offset(offset).limit(int(page_size))
  302. data = query.all()
  303. ret_data = {
  304. 'total': total,
  305. 'agent_list': [
  306. {
  307. 'id': agent.id,
  308. 'name': agent.name,
  309. 'display_name': agent.display_name,
  310. 'type': agent.type,
  311. 'execution_model': agent.execution_model,
  312. 'create_user': agent.create_user,
  313. 'update_user': agent.update_user,
  314. 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  315. 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
  316. }
  317. for agent in data
  318. ]
  319. }
  320. return wrap_response(200, data=ret_data)
  321. @app.route("/api/getNativeAgentConfiguration", methods=["GET"])
  322. def get_native_agent_configuration():
  323. """
  324. 获取指定Agent的配置
  325. :return:
  326. """
  327. agent_id = request.args.get('agent_id')
  328. if not agent_id:
  329. return wrap_response(404, msg='agent_id is required')
  330. with app.session_maker() as session:
  331. agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
  332. if not agent:
  333. return wrap_response(404, msg='Agent not found')
  334. data = {
  335. 'id': agent.id,
  336. 'name': agent.name,
  337. 'display_name': agent.display_name,
  338. 'type': agent.type,
  339. 'execution_model': agent.execution_model,
  340. 'system_prompt': agent.system_prompt,
  341. 'task_prompt': agent.task_prompt,
  342. 'tools': json.loads(agent.tools),
  343. 'sub_agents': json.loads(agent.sub_agents),
  344. 'extra_params': json.loads(agent.extra_params),
  345. 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  346. 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
  347. }
  348. return wrap_response(200, data=data)
  349. @app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
  350. def save_native_agent_configuration():
  351. """
  352. 保存Agent配置
  353. :return:
  354. """
  355. req_data = request.json
  356. agent_id = req_data.get('agent_id', None)
  357. name = req_data.get('name')
  358. display_name = req_data.get('display_name', None)
  359. type_ = req_data.get('type', 0)
  360. execution_model = req_data.get('execution_model', None)
  361. system_prompt = req_data.get('system_prompt', None)
  362. task_prompt = req_data.get('task_prompt', None)
  363. tools = json.dumps(req_data.get('tools', []))
  364. sub_agents = json.dumps(req_data.get('sub_agents', []))
  365. extra_params = req_data.get('extra_params', {})
  366. operate_user = req_data.get('user', None)
  367. if isinstance(extra_params, dict):
  368. extra_params = json.dumps(extra_params)
  369. elif isinstance(extra_params, str):
  370. try:
  371. json.loads(extra_params)
  372. except json.JSONDecodeError:
  373. return wrap_response(400, msg='extra_params should be a valid JSON object or string')
  374. if not extra_params:
  375. extra_params = '{}'
  376. if not name:
  377. return wrap_response(400, msg='name is required')
  378. with app.session_maker() as session:
  379. if agent_id:
  380. agent_id = int(agent_id)
  381. agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
  382. if not agent:
  383. return wrap_response(404, msg='Agent not found')
  384. agent.name = name
  385. agent.display_name = display_name
  386. agent.type = type_
  387. agent.execution_model = execution_model
  388. agent.system_prompt = system_prompt
  389. agent.task_prompt = task_prompt
  390. agent.tools = tools
  391. agent.sub_agents = sub_agents
  392. agent.extra_params = extra_params
  393. agent.update_user = operate_user
  394. else:
  395. agent = AgentConfiguration(
  396. name=name,
  397. display_name=display_name,
  398. type=type_,
  399. execution_model=execution_model,
  400. system_prompt=system_prompt,
  401. task_prompt=task_prompt,
  402. tools=tools,
  403. sub_agents=sub_agents,
  404. extra_params=extra_params,
  405. create_user=operate_user,
  406. update_user=operate_user
  407. )
  408. session.add(agent)
  409. session.commit()
  410. return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
  411. @app.route("/api/deleteNativeAgentConfiguration", methods=["POST"])
  412. def delete_native_agent_configuration():
  413. """
  414. 删除指定Agent配置(软删除,设置is_delete=1)
  415. :return:
  416. """
  417. req_data = request.json
  418. agent_id = req_data.get('agent_id', None)
  419. if not agent_id:
  420. return wrap_response(400, msg='agent_id is required')
  421. try:
  422. agent_id = int(agent_id)
  423. except ValueError:
  424. return wrap_response(400, msg='agent_id must be an integer')
  425. with app.session_maker() as session:
  426. agent = session.query(AgentConfiguration).filter(
  427. AgentConfiguration.id == agent_id,
  428. AgentConfiguration.is_delete == 0
  429. ).first()
  430. if not agent:
  431. return wrap_response(404, msg='Agent not found')
  432. agent.is_delete = 1
  433. session.commit()
  434. return wrap_response(200, msg='Agent configuration deleted successfully')
  435. @app.route("/api/getModuleList", methods=["GET"])
  436. def get_module_list():
  437. """
  438. 获取所有的模块列表,支持分页查询
  439. :return:
  440. """
  441. page = request.args.get('page', 1)
  442. page_size = request.args.get('page_size', 50)
  443. try:
  444. page = int(page)
  445. page_size = int(page_size)
  446. except Exception as e:
  447. return wrap_response(400, msg="Invalid parameter: {}".format(e))
  448. offset = (page - 1) * page_size
  449. with app.session_maker() as session:
  450. query = session.query(
  451. ServiceModule,
  452. AgentConfiguration.name.label("default_agent_name")
  453. ).outerjoin(
  454. AgentConfiguration,
  455. ServiceModule.default_agent_id == AgentConfiguration.id
  456. ).filter(ServiceModule.is_delete == 0)
  457. total = query.count()
  458. modules = query.offset(offset).limit(page_size).all()
  459. ret_data = {
  460. 'total': total,
  461. 'module_list': [
  462. {
  463. 'id': module.id,
  464. 'name': module.name,
  465. 'display_name': module.display_name,
  466. 'default_agent_type': module.default_agent_type,
  467. 'default_agent_id': module.default_agent_id,
  468. 'default_agent_name': default_agent_name,
  469. 'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  470. 'update_time': module.update_time.strftime('%Y-%m-%d %H:%M:%S')
  471. }
  472. for module, default_agent_name in modules
  473. ]
  474. }
  475. return wrap_response(200, data=ret_data)
  476. @app.route("/api/getModuleConfiguration", methods=["GET"])
  477. def get_module_configuration():
  478. """
  479. 获取指定模块的配置
  480. :return:
  481. """
  482. module_id = request.args.get('module_id')
  483. if not module_id:
  484. return wrap_response(404, msg='module_id is required')
  485. with app.session_maker() as session:
  486. module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
  487. if not module:
  488. return wrap_response(404, msg='Module not found')
  489. data = {
  490. 'id': module.id,
  491. 'name': module.name,
  492. 'display_name': module.display_name,
  493. 'default_agent_type': module.default_agent_type,
  494. 'default_agent_id': module.default_agent_id,
  495. 'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
  496. 'updated_time': module.updated_time.strftime('%Y-%m-%d %H:%M:%S')
  497. }
  498. return wrap_response(200, data=data)
  499. @app.route("/api/saveModuleConfiguration", methods=["POST"])
  500. def save_module_configuration():
  501. """
  502. 保存模块配置
  503. :return:
  504. """
  505. req_data = request.json
  506. module_id = req_data.get('module_id', None)
  507. name = req_data.get('name')
  508. display_name = req_data.get('display_name', None)
  509. default_agent_type = req_data.get('default_agent_type', 0)
  510. default_agent_id = req_data.get('default_agent_id', None)
  511. if not name:
  512. return wrap_response(400, msg='name is required')
  513. with app.session_maker() as session:
  514. if module_id:
  515. module_id = int(module_id)
  516. module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
  517. if not module:
  518. return wrap_response(404, msg='Module not found')
  519. module.name = name
  520. module.display_name = display_name
  521. module.default_agent_type = default_agent_type
  522. module.default_agent_id = default_agent_id
  523. else:
  524. module = ServiceModule(
  525. name=name,
  526. display_name=display_name,
  527. default_agent_type=default_agent_type,
  528. default_agent_id=default_agent_id
  529. )
  530. session.add(module)
  531. session.commit()
  532. return wrap_response(200, msg='Module configuration saved successfully', data={'id': module.id})
  533. @app.route("/api/getTestTaskList", methods=["GET"])
  534. def get_test_task_list():
  535. """
  536. 获取单元测试任务列表
  537. :return:
  538. """
  539. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  540. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  541. try:
  542. page_num = int(page_num)
  543. page_size = int(page_size)
  544. except Exception as e:
  545. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  546. response = app.task_manager.get_test_task_list(page_num, page_size)
  547. return wrap_response(200, data=response)
  548. @app.route("/api/getTestTaskConversations", methods=["GET"])
  549. def get_test_task_conversations():
  550. """
  551. 获取单元测试对话任务列表
  552. :return:
  553. """
  554. task_id = request.args.get("taskId", None)
  555. if not task_id:
  556. return wrap_response(404, msg='task_id is required')
  557. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  558. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  559. try:
  560. page_num = int(page_num)
  561. page_size = int(page_size)
  562. except Exception as e:
  563. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  564. response = app.task_manager.get_test_task_conversations(int(task_id), page_num, page_size)
  565. return wrap_response(200, data=response)
  566. @app.route("/api/createTestTask", methods=["POST"])
  567. def create_test_task():
  568. """
  569. 创建单元测试任务
  570. :return:
  571. """
  572. req_data = request.json
  573. agent_id = req_data.get('agentId', None)
  574. module_id = req_data.get('moduleId', None)
  575. evaluate_type = req_data.get('evaluateType', None)
  576. if not agent_id:
  577. return wrap_response(404, msg='agent id is required')
  578. if not module_id:
  579. return wrap_response(404, msg='module id is required')
  580. if not evaluate_type:
  581. return wrap_response(404, msg='evaluate_type id is required')
  582. app.task_manager.create_task(agent_id, module_id, evaluate_type)
  583. return wrap_response(200)
  584. @app.route("/api/stopTestTask", methods=["POST"])
  585. def stop_test_task():
  586. """
  587. 停止单元测试任务
  588. :return:
  589. """
  590. req_data = request.json
  591. task_id = req_data.get('taskId', None)
  592. if not task_id:
  593. return wrap_response(400, msg='task id is required')
  594. task = app.task_manager.get_task(task_id)
  595. if task.status not in (TestTaskStatus.NOT_STARTED.value, TestTaskStatus.IN_PROGRESS.value):
  596. return wrap_response(400, msg='task status is invalid')
  597. app.task_manager.cancel_task(task_id)
  598. return wrap_response(200)
  599. @app.route("/api/resumeTestTask", methods=["POST"])
  600. def resume_test_task():
  601. """
  602. 恢复停止的单元测试任务
  603. :return:
  604. """
  605. req_data = request.json
  606. task_id = req_data.get('taskId', None)
  607. if not task_id:
  608. return wrap_response(400, msg='task id is required')
  609. task = app.task_manager.get_task(task_id)
  610. if task.status != TestTaskStatus.CANCELLED.value:
  611. return wrap_response(400, msg='task status is invalid')
  612. app.task_manager.resume_task(task_id)
  613. return wrap_response(200)
  614. @app.route("/api/getEvaluateType", methods=["GET"])
  615. def get_evaluate_type():
  616. """
  617. 获取评估类型
  618. :return:
  619. """
  620. name_desc_list = [
  621. {
  622. "type": item.value,
  623. "desc": item.description
  624. }
  625. for item in EvaluateType]
  626. return wrap_response(code=200, data=name_desc_list)
  627. @app.route("/api/getDatasetList", methods=["GET"])
  628. def get_dataset_list():
  629. """
  630. 获取数据集列表
  631. :return:
  632. """
  633. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  634. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  635. try:
  636. page_num = int(page_num)
  637. page_size = int(page_size)
  638. except Exception as e:
  639. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  640. response = app.dataset_service.get_dataset_list(page_num, page_size)
  641. return wrap_response(200, data=response)
  642. @app.route("/api/getConversationDataList", methods=["GET"])
  643. def get_conversation_data_list():
  644. """
  645. 获取对话列表
  646. :return:
  647. """
  648. dataset_id = request.args.get("datasetId", None)
  649. if not dataset_id:
  650. return wrap_response(404, msg='dataset_id is required')
  651. page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
  652. page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
  653. try:
  654. page_num = int(page_num)
  655. page_size = int(page_size)
  656. except Exception as e:
  657. return wrap_response(404, msg="Invalid parameter: {}".format(e))
  658. response = app.dataset_service.get_conversation_data_list(int(dataset_id), page_num, page_size)
  659. return wrap_response(200, data=response)
  660. @app.route("/api/getToolList", methods=["GET"])
  661. def get_tool_list():
  662. """
  663. 获取所有的工具列表
  664. :return:
  665. """
  666. tools = []
  667. for tool_name, tool in global_tool_map.items():
  668. tools.append({
  669. 'name': tool_name,
  670. 'description': tool.get_function_description(),
  671. 'parameters': tool.parameters if hasattr(tool, 'parameters') else {}
  672. })
  673. return wrap_response(200, data=tools)
  674. @app.route("/api/getModuleAgentTypes", methods=["GET"])
  675. def get_agent_types():
  676. """
  677. 获取所有的Agent类型
  678. :return:
  679. """
  680. agent_types = [
  681. {'type': 0, 'display_name': '原生'},
  682. {'type': 1, 'display_name': 'Coze'}
  683. ]
  684. return wrap_response(200, data=agent_types)
  685. @app.errorhandler(werkzeug.exceptions.BadRequest)
  686. def handle_bad_request(e):
  687. logger.error(e)
  688. return wrap_response(400, msg='Bad Request: {}'.format(e.description))
  689. if __name__ == '__main__':
  690. parser = ArgumentParser()
  691. parser.add_argument('--prod', action='store_true')
  692. parser.add_argument('--host', default='127.0.0.1')
  693. parser.add_argument('--port', type=int, default=8083)
  694. parser.add_argument('--log-level', default='INFO')
  695. args = parser.parse_args()
  696. config = configs.get()
  697. logging_level = logging.getLevelName(args.log_level)
  698. setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
  699. # set db config
  700. agent_db_config = config['database']['ai_agent']
  701. growth_db_config = config['database']['growth']
  702. user_db_config = config['storage']['user']
  703. staff_db_config = config['storage']['staff']
  704. agent_state_db_config = config['storage']['agent_state']
  705. chat_history_db_config = config['storage']['chat_history']
  706. # init user manager
  707. user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
  708. app.user_manager = user_manager
  709. # init session manager
  710. session_manager = MySQLSessionManager(
  711. db_config=agent_db_config,
  712. staff_table=staff_db_config['table'],
  713. user_table=user_db_config['table'],
  714. agent_state_table=agent_state_db_config['table'],
  715. chat_history_table=chat_history_db_config['table']
  716. )
  717. app.session_manager = session_manager
  718. agent_db_engine = create_ai_agent_db_engine()
  719. app.session_maker = sessionmaker(bind=agent_db_engine)
  720. dataset_service = DatasetService(session_maker=sessionmaker(bind=agent_db_engine))
  721. app.dataset_service = dataset_service
  722. task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine), dataset_service=dataset_service)
  723. app.task_manager = task_manager
  724. task_manager.recover_tasks()
  725. wecom_db_config = config['storage']['user_relation']
  726. user_relation_manager = MySQLUserRelationManager(
  727. agent_db_config, growth_db_config,
  728. config['storage']['staff']['table'],
  729. user_db_config['table'],
  730. wecom_db_config['table']['staff'],
  731. wecom_db_config['table']['relation'],
  732. wecom_db_config['table']['user']
  733. )
  734. app.user_relation_manager = user_relation_manager
  735. app.history_dialogue_service = HistoryDialogueService(
  736. config['storage']['history_dialogue']['api_base_url']
  737. )
  738. app.run(debug=not args.prod, host=args.host, port=args.port)