|
@@ -1,37 +1,44 @@
|
|
|
#! /usr/bin/env python
|
|
|
# -*- coding: utf-8 -*-
|
|
|
# vim:fenc=utf-8
|
|
|
-import time
|
|
|
+import json
|
|
|
import logging
|
|
|
-import werkzeug.exceptions
|
|
|
-from flask import Flask, request, jsonify
|
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
+import werkzeug.exceptions
|
|
|
+from flask import Flask, request, jsonify
|
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
|
|
+import pqai_agent_server.utils
|
|
|
+from pqai_agent import chat_service, prompt_templates
|
|
|
from pqai_agent import configs
|
|
|
-
|
|
|
-from pqai_agent import logging_service, chat_service, prompt_templates
|
|
|
-from pqai_agent.agents.message_reply_agent import MessageReplyAgent
|
|
|
+from pqai_agent.chat_service import OpenAICompatible
|
|
|
from pqai_agent.data_models.agent_configuration import AgentConfiguration
|
|
|
from pqai_agent.data_models.service_module import ServiceModule
|
|
|
from pqai_agent.history_dialogue_service import HistoryDialogueService
|
|
|
+from pqai_agent.logging import logger, setup_root_logger
|
|
|
+from pqai_agent.toolkit import global_tool_map
|
|
|
from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
|
|
|
from pqai_agent.utils.db_utils import create_ai_agent_db_engine
|
|
|
from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
|
|
|
+from pqai_agent_server.agent_task_server import AgentTaskManager
|
|
|
from pqai_agent_server.const import AgentApiConst
|
|
|
+from pqai_agent_server.const.status_enum import TestTaskStatus
|
|
|
+from pqai_agent_server.const.type_enum import EvaluateType
|
|
|
+from pqai_agent_server.dataset_service import DatasetService
|
|
|
from pqai_agent_server.models import MySQLSessionManager
|
|
|
-from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
|
|
|
+from pqai_agent_server.task_server import TaskManager
|
|
|
from pqai_agent_server.utils import (
|
|
|
run_extractor_prompt,
|
|
|
run_chat_prompt,
|
|
|
run_response_type_prompt,
|
|
|
)
|
|
|
+from pqai_agent_server.utils import wrap_response
|
|
|
|
|
|
app = Flask('agent_api_server')
|
|
|
-logger = logging_service.logger
|
|
|
const = AgentApiConst()
|
|
|
|
|
|
+
|
|
|
@app.route('/api/listStaffs', methods=['GET'])
|
|
|
def list_staffs():
|
|
|
staff_data = app.user_relation_manager.list_staffs()
|
|
@@ -91,34 +98,24 @@ def get_dialogue_history():
|
|
|
|
|
|
@app.route('/api/listModels', methods=['GET'])
|
|
|
def list_models():
|
|
|
- models = [
|
|
|
- {
|
|
|
- 'model_type': 'openai_compatible',
|
|
|
- 'model_name': chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
|
|
|
- 'display_name': 'DeepSeek V3 on 火山'
|
|
|
- },
|
|
|
- {
|
|
|
- 'model_type': 'openai_compatible',
|
|
|
- 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
|
- 'display_name': '豆包Pro 32K'
|
|
|
- },
|
|
|
- {
|
|
|
- 'model_type': 'openai_compatible',
|
|
|
- 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
|
|
|
- 'display_name': '豆包Pro 1.5'
|
|
|
- },
|
|
|
- {
|
|
|
- 'model_type': 'openai_compatible',
|
|
|
- 'model_name': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
|
|
|
- 'display_name': 'DeepSeek V3联网 on 火山'
|
|
|
- },
|
|
|
+ models = {
|
|
|
+ "deepseek-chat": chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
|
|
|
+ "gpt-4o": chat_service.OPENAI_MODEL_GPT_4o,
|
|
|
+ "gpt-4o-mini": chat_service.OPENAI_MODEL_GPT_4o_mini,
|
|
|
+ "doubao-pro-32k": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
|
+ "doubao-pro-1.5": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
|
|
|
+ "doubao-1.5-vision-pro": chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
|
|
|
+ "openrouter-gemini-2.5-pro": chat_service.OPENROUTER_MODEL_GEMINI_2_5_PRO,
|
|
|
+ }
|
|
|
+ ret_data = [
|
|
|
{
|
|
|
'model_type': 'openai_compatible',
|
|
|
- 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
|
|
|
- 'display_name': '豆包1.5视觉理解Pro'
|
|
|
- },
|
|
|
+ 'model_name': model_name,
|
|
|
+ 'display_name': f"{model_display_name} ({OpenAICompatible.get_price(model_name).get_cny_brief()})"
|
|
|
+ }
|
|
|
+ for model_display_name, model_name in models.items()
|
|
|
]
|
|
|
- return wrap_response(200, data=models)
|
|
|
+ return wrap_response(200, data=ret_data)
|
|
|
|
|
|
|
|
|
@app.route('/api/listScenes', methods=['GET'])
|
|
@@ -146,8 +143,8 @@ def get_base_prompt():
|
|
|
model_map = {
|
|
|
'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
|
'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
|
- 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
|
|
|
- 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
|
|
|
+ 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
|
|
|
+ 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
|
|
|
'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
|
|
|
}
|
|
|
if scene not in prompt_map:
|
|
@@ -178,6 +175,7 @@ def run_prompt():
|
|
|
logger.error(e)
|
|
|
return wrap_response(500, msg='Error: {}'.format(e))
|
|
|
|
|
|
+
|
|
|
@app.route('/api/formatForPrompt', methods=['POST'])
|
|
|
def format_data_for_prompt():
|
|
|
try:
|
|
@@ -297,8 +295,8 @@ def send_message():
|
|
|
return wrap_response(200, msg="暂不实现功能")
|
|
|
|
|
|
|
|
|
-@app.route("/api/quitHumanInterventionStatus", methods=["POST"])
|
|
|
-def quit_human_interventions_status():
|
|
|
+@app.route("/api/quitHumanIntervention", methods=["POST"])
|
|
|
+def quit_human_intervention():
|
|
|
"""
|
|
|
退出人工介入状态
|
|
|
:return:
|
|
@@ -308,9 +306,28 @@ def quit_human_interventions_status():
|
|
|
user_id = req_data["user_id"]
|
|
|
if not user_id or not staff_id:
|
|
|
return wrap_response(404, msg="user_id and staff_id are required")
|
|
|
- response = quit_human_intervention_status(user_id, staff_id)
|
|
|
+ if pqai_agent_server.utils.common.quit_human_intervention(user_id, staff_id):
|
|
|
+ return wrap_response(200, msg="success")
|
|
|
+ else:
|
|
|
+ return wrap_response(500, msg="error")
|
|
|
+
|
|
|
+
|
|
|
+@app.route("/api/enterHumanIntervention", methods=["POST"])
|
|
|
+def enter_human_intervention():
|
|
|
+ """
|
|
|
+ 进入人工介入状态
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ req_data = request.json
|
|
|
+ staff_id = req_data["staff_id"]
|
|
|
+ user_id = req_data["user_id"]
|
|
|
+ if not user_id or not staff_id:
|
|
|
+ return wrap_response(404, msg="user_id and staff_id are required")
|
|
|
+ if pqai_agent_server.utils.common.enter_human_intervention(user_id, staff_id):
|
|
|
+ return wrap_response(200, msg="success")
|
|
|
+ else:
|
|
|
+ return wrap_response(500, msg="error")
|
|
|
|
|
|
- return wrap_response(200, data=response)
|
|
|
|
|
|
## Agent管理接口
|
|
|
@app.route("/api/getNativeAgentList", methods=["GET"])
|
|
@@ -332,22 +349,29 @@ def get_native_agent_list():
|
|
|
query = query.filter(AgentConfiguration.create_user == create_user)
|
|
|
if update_user:
|
|
|
query = query.filter(AgentConfiguration.update_user == update_user)
|
|
|
+ total = query.count()
|
|
|
query = query.offset(offset).limit(int(page_size))
|
|
|
data = query.all()
|
|
|
- ret_data = [
|
|
|
- {
|
|
|
- 'id': agent.id,
|
|
|
- 'name': agent.name,
|
|
|
- 'display_name': agent.display_name,
|
|
|
- 'type': agent.type,
|
|
|
- 'execution_model': agent.execution_model,
|
|
|
- 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
- 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
- }
|
|
|
- for agent in data
|
|
|
- ]
|
|
|
+ ret_data = {
|
|
|
+ 'total': total,
|
|
|
+ 'agent_list': [
|
|
|
+ {
|
|
|
+ 'id': agent.id,
|
|
|
+ 'name': agent.name,
|
|
|
+ 'display_name': agent.display_name,
|
|
|
+ 'type': agent.type,
|
|
|
+ 'execution_model': agent.execution_model,
|
|
|
+ 'create_user': agent.create_user,
|
|
|
+ 'update_user': agent.update_user,
|
|
|
+ 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
+ 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
+ }
|
|
|
+ for agent in data
|
|
|
+ ]
|
|
|
+ }
|
|
|
return wrap_response(200, data=ret_data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/getNativeAgentConfiguration", methods=["GET"])
|
|
|
def get_native_agent_configuration():
|
|
|
"""
|
|
@@ -371,14 +395,15 @@ def get_native_agent_configuration():
|
|
|
'execution_model': agent.execution_model,
|
|
|
'system_prompt': agent.system_prompt,
|
|
|
'task_prompt': agent.task_prompt,
|
|
|
- 'tools': agent.tools,
|
|
|
- 'sub_agents': agent.sub_agents,
|
|
|
- 'extra_params': agent.extra_params,
|
|
|
+ 'tools': json.loads(agent.tools),
|
|
|
+ 'sub_agents': json.loads(agent.sub_agents),
|
|
|
+ 'extra_params': json.loads(agent.extra_params),
|
|
|
'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
}
|
|
|
return wrap_response(200, data=data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
|
|
|
def save_native_agent_configuration():
|
|
|
"""
|
|
@@ -393,9 +418,19 @@ def save_native_agent_configuration():
|
|
|
execution_model = req_data.get('execution_model', None)
|
|
|
system_prompt = req_data.get('system_prompt', None)
|
|
|
task_prompt = req_data.get('task_prompt', None)
|
|
|
- tools = req_data.get('tools', [])
|
|
|
- sub_agents = req_data.get('sub_agents', [])
|
|
|
+ tools = json.dumps(req_data.get('tools', []))
|
|
|
+ sub_agents = json.dumps(req_data.get('sub_agents', []))
|
|
|
extra_params = req_data.get('extra_params', {})
|
|
|
+ operate_user = req_data.get('user', None)
|
|
|
+ if isinstance(extra_params, dict):
|
|
|
+ extra_params = json.dumps(extra_params)
|
|
|
+ elif isinstance(extra_params, str):
|
|
|
+ try:
|
|
|
+ json.loads(extra_params)
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ return wrap_response(400, msg='extra_params should be a valid JSON object or string')
|
|
|
+ if not extra_params:
|
|
|
+ extra_params = '{}'
|
|
|
|
|
|
if not name:
|
|
|
return wrap_response(400, msg='name is required')
|
|
@@ -415,6 +450,7 @@ def save_native_agent_configuration():
|
|
|
agent.tools = tools
|
|
|
agent.sub_agents = sub_agents
|
|
|
agent.extra_params = extra_params
|
|
|
+ agent.update_user = operate_user
|
|
|
else:
|
|
|
agent = AgentConfiguration(
|
|
|
name=name,
|
|
@@ -425,37 +461,88 @@ def save_native_agent_configuration():
|
|
|
task_prompt=task_prompt,
|
|
|
tools=tools,
|
|
|
sub_agents=sub_agents,
|
|
|
- extra_params=extra_params
|
|
|
+ extra_params=extra_params,
|
|
|
+ create_user=operate_user,
|
|
|
+ update_user=operate_user
|
|
|
)
|
|
|
session.add(agent)
|
|
|
|
|
|
session.commit()
|
|
|
return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
|
|
|
|
|
|
+
|
|
|
+@app.route("/api/deleteNativeAgentConfiguration", methods=["POST"])
|
|
|
+def delete_native_agent_configuration():
|
|
|
+ """
|
|
|
+ 删除指定Agent配置(软删除,设置is_delete=1)
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ req_data = request.json
|
|
|
+ agent_id = req_data.get('agent_id', None)
|
|
|
+ if not agent_id:
|
|
|
+ return wrap_response(400, msg='agent_id is required')
|
|
|
+ try:
|
|
|
+ agent_id = int(agent_id)
|
|
|
+ except ValueError:
|
|
|
+ return wrap_response(400, msg='agent_id must be an integer')
|
|
|
+
|
|
|
+ with app.session_maker() as session:
|
|
|
+ agent = session.query(AgentConfiguration).filter(
|
|
|
+ AgentConfiguration.id == agent_id,
|
|
|
+ AgentConfiguration.is_delete == 0
|
|
|
+ ).first()
|
|
|
+ if not agent:
|
|
|
+ return wrap_response(404, msg='Agent not found')
|
|
|
+ agent.is_delete = 1
|
|
|
+ session.commit()
|
|
|
+ return wrap_response(200, msg='Agent configuration deleted successfully')
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
@app.route("/api/getModuleList", methods=["GET"])
|
|
|
def get_module_list():
|
|
|
"""
|
|
|
- 获取所有的模块列表
|
|
|
+ 获取所有的模块列表,支持分页查询
|
|
|
:return:
|
|
|
"""
|
|
|
+ page = request.args.get('page', 1)
|
|
|
+ page_size = request.args.get('page_size', 50)
|
|
|
+ try:
|
|
|
+ page = int(page)
|
|
|
+ page_size = int(page_size)
|
|
|
+ except Exception as e:
|
|
|
+ return wrap_response(400, msg="Invalid parameter: {}".format(e))
|
|
|
+
|
|
|
+ offset = (page - 1) * page_size
|
|
|
with app.session_maker() as session:
|
|
|
- query = session.query(ServiceModule) \
|
|
|
- .filter(ServiceModule.is_delete == 0)
|
|
|
- data = query.all()
|
|
|
- ret_data = [
|
|
|
- {
|
|
|
- 'id': module.id,
|
|
|
- 'name': module.name,
|
|
|
- 'display_name': module.display_name,
|
|
|
- 'default_agent_type': module.default_agent_type,
|
|
|
- 'default_agent_id': module.default_agent_id,
|
|
|
- 'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
- 'update_time': module.update_time.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
- }
|
|
|
- for module in data
|
|
|
- ]
|
|
|
+ query = session.query(
|
|
|
+ ServiceModule,
|
|
|
+ AgentConfiguration.name.label("default_agent_name")
|
|
|
+ ).outerjoin(
|
|
|
+ AgentConfiguration,
|
|
|
+ ServiceModule.default_agent_id == AgentConfiguration.id
|
|
|
+ ).filter(ServiceModule.is_delete == 0)
|
|
|
+ total = query.count()
|
|
|
+ modules = query.offset(offset).limit(page_size).all()
|
|
|
+ ret_data = {
|
|
|
+ 'total': total,
|
|
|
+ 'module_list': [
|
|
|
+ {
|
|
|
+ 'id': module.id,
|
|
|
+ 'name': module.name,
|
|
|
+ 'display_name': module.display_name,
|
|
|
+ 'default_agent_type': module.default_agent_type,
|
|
|
+ 'default_agent_id': module.default_agent_id,
|
|
|
+ 'default_agent_name': default_agent_name,
|
|
|
+ 'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
+ 'update_time': module.update_time.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
+ }
|
|
|
+ for module, default_agent_name in modules
|
|
|
+ ]
|
|
|
+ }
|
|
|
return wrap_response(200, data=ret_data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/getModuleConfiguration", methods=["GET"])
|
|
|
def get_module_configuration():
|
|
|
"""
|
|
@@ -478,10 +565,11 @@ def get_module_configuration():
|
|
|
'default_agent_type': module.default_agent_type,
|
|
|
'default_agent_id': module.default_agent_id,
|
|
|
'create_time': module.create_time.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
- 'updated_time': module.updated_time.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
+ 'update_time': module.update_time.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
}
|
|
|
return wrap_response(200, data=data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/saveModuleConfiguration", methods=["POST"])
|
|
|
def save_module_configuration():
|
|
|
"""
|
|
@@ -520,6 +608,248 @@ def save_module_configuration():
|
|
|
session.commit()
|
|
|
return wrap_response(200, msg='Module configuration saved successfully', data={'id': module.id})
|
|
|
|
|
|
+@app.route("/api/deleteModuleConfiguration", methods=["POST"])
|
|
|
+def delete_module_configuration():
|
|
|
+ """
|
|
|
+ 删除指定模块配置(软删除,设置is_delete=1)
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ req_data = request.json
|
|
|
+ module_id = req_data.get('module_id', None)
|
|
|
+ if not module_id:
|
|
|
+ return wrap_response(400, msg='module_id is required')
|
|
|
+ try:
|
|
|
+ module_id = int(module_id)
|
|
|
+ except ValueError:
|
|
|
+ return wrap_response(400, msg='module_id must be an integer')
|
|
|
+
|
|
|
+ with app.session_maker() as session:
|
|
|
+ module = session.query(ServiceModule).filter(
|
|
|
+ ServiceModule.id == module_id,
|
|
|
+ ServiceModule.is_delete == 0
|
|
|
+ ).first()
|
|
|
+ if not module:
|
|
|
+ return wrap_response(404, msg='Module not found')
|
|
|
+ module.is_delete = 1
|
|
|
+ session.commit()
|
|
|
+ return wrap_response(200, msg='Module configuration deleted successfully')
|
|
|
+
|
|
|
+@app.route("/api/getToolList", methods=["GET"])
|
|
|
+def get_tool_list():
|
|
|
+ """
|
|
|
+ 获取所有的工具列表
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ tools = []
|
|
|
+ for tool_name, tool in global_tool_map.items():
|
|
|
+ tools.append({
|
|
|
+ 'name': tool_name,
|
|
|
+ 'description': tool.get_function_description(),
|
|
|
+ 'parameters': tool.parameters if hasattr(tool, 'parameters') else {}
|
|
|
+ })
|
|
|
+ return wrap_response(200, data=tools)
|
|
|
+
|
|
|
+@app.route("/api/getModuleAgentTypes", methods=["GET"])
|
|
|
+def get_agent_types():
|
|
|
+ """
|
|
|
+ 获取所有的Agent类型
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ agent_types = [
|
|
|
+ {'type': 0, 'display_name': '原生'},
|
|
|
+ {'type': 1, 'display_name': 'Coze'}
|
|
|
+ ]
|
|
|
+ return wrap_response(200, data=agent_types)
|
|
|
+
|
|
|
+@app.route("/api/getTestTaskList", methods=["GET"])
|
|
|
+def get_test_task_list():
|
|
|
+ """
|
|
|
+ 获取单元测试任务列表
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
|
|
|
+ page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
|
|
|
+ try:
|
|
|
+ page_num = int(page_num)
|
|
|
+ page_size = int(page_size)
|
|
|
+ except Exception as e:
|
|
|
+ return wrap_response(404, msg="Invalid parameter: {}".format(e))
|
|
|
+ response = app.task_manager.get_test_task_list(page_num, page_size)
|
|
|
+ return wrap_response(200, data=response)
|
|
|
+
|
|
|
+
|
|
|
+@app.route("/api/getTestTaskConversations", methods=["GET"])
|
|
|
+def get_test_task_conversations():
|
|
|
+ """
|
|
|
+ 获取单元测试对话任务列表
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ task_id = request.args.get("taskId", None)
|
|
|
+ if not task_id:
|
|
|
+ return wrap_response(404, msg='task_id is required')
|
|
|
+ page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
|
|
|
+ page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
|
|
|
+ try:
|
|
|
+ page_num = int(page_num)
|
|
|
+ page_size = int(page_size)
|
|
|
+ except Exception as e:
|
|
|
+ return wrap_response(404, msg="Invalid parameter: {}".format(e))
|
|
|
+ response = app.task_manager.get_test_task_conversations(int(task_id), page_num, page_size)
|
|
|
+ return wrap_response(200, data=response)
|
|
|
+
|
|
|
+
|
|
|
+@app.route("/api/createTestTask", methods=["POST"])
|
|
|
+def create_test_task():
|
|
|
+ """
|
|
|
+ 创建单元测试任务
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ req_data = request.json
|
|
|
+ agent_id = req_data.get('agentId', None)
|
|
|
+ module_id = req_data.get('moduleId', None)
|
|
|
+ evaluate_type = req_data.get('evaluateType', None)
|
|
|
+ if not agent_id:
|
|
|
+ return wrap_response(404, msg='agent id is required')
|
|
|
+ if not module_id:
|
|
|
+ return wrap_response(404, msg='module id is required')
|
|
|
+ if not evaluate_type:
|
|
|
+ return wrap_response(404, msg='evaluate_type id is required')
|
|
|
+ app.task_manager.create_task(agent_id, module_id, evaluate_type)
|
|
|
+ return wrap_response(200)
|
|
|
+
|
|
|
+
|
|
|
+@app.route("/api/stopTestTask", methods=["POST"])
|
|
|
+def stop_test_task():
|
|
|
+ """
|
|
|
+ 停止单元测试任务
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ req_data = request.json
|
|
|
+ task_id = req_data.get('taskId', None)
|
|
|
+ if not task_id:
|
|
|
+ return wrap_response(400, msg='task id is required')
|
|
|
+ task = app.task_manager.get_task(task_id)
|
|
|
+ if task.status not in (TestTaskStatus.NOT_STARTED.value, TestTaskStatus.IN_PROGRESS.value):
|
|
|
+ return wrap_response(400, msg='task status is invalid')
|
|
|
+ app.task_manager.cancel_task(task_id)
|
|
|
+ return wrap_response(200)
|
|
|
+
|
|
|
+
|
|
|
+@app.route("/api/resumeTestTask", methods=["POST"])
|
|
|
+def resume_test_task():
|
|
|
+ """
|
|
|
+ 恢复停止的单元测试任务
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ req_data = request.json
|
|
|
+ task_id = req_data.get('taskId', None)
|
|
|
+ if not task_id:
|
|
|
+ return wrap_response(400, msg='task id is required')
|
|
|
+ task = app.task_manager.get_task(task_id)
|
|
|
+ if task.status != TestTaskStatus.CANCELLED.value:
|
|
|
+ return wrap_response(400, msg='task status is invalid')
|
|
|
+ app.task_manager.resume_task(task_id)
|
|
|
+ return wrap_response(200)
|
|
|
+
|
|
|
+
|
|
|
+@app.route("/api/getEvaluateType", methods=["GET"])
|
|
|
+def get_evaluate_type():
|
|
|
+ """
|
|
|
+ 获取评估类型
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ name_desc_list = [
|
|
|
+ {
|
|
|
+ "type": item.value,
|
|
|
+ "desc": item.description
|
|
|
+ }
|
|
|
+ for item in EvaluateType]
|
|
|
+ return wrap_response(code=200, data=name_desc_list)
|
|
|
+
|
|
|
+
|
|
|
+@app.route("/api/getDatasetList", methods=["GET"])
|
|
|
+def get_dataset_list():
|
|
|
+ """
|
|
|
+ 获取数据集列表
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
|
|
|
+ page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
|
|
|
+ try:
|
|
|
+ page_num = int(page_num)
|
|
|
+ page_size = int(page_size)
|
|
|
+ except Exception as e:
|
|
|
+ return wrap_response(404, msg="Invalid parameter: {}".format(e))
|
|
|
+ response = app.dataset_service.get_dataset_list(page_num, page_size)
|
|
|
+ return wrap_response(200, data=response)
|
|
|
+
|
|
|
+
|
|
|
+@app.route("/api/getConversationDataList", methods=["GET"])
|
|
|
+def get_conversation_data_list():
|
|
|
+ """
|
|
|
+ 获取对话列表
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ dataset_id = request.args.get("datasetId", None)
|
|
|
+ if not dataset_id:
|
|
|
+ return wrap_response(404, msg='dataset_id is required')
|
|
|
+ page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
|
|
|
+ page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
|
|
|
+ try:
|
|
|
+ page_num = int(page_num)
|
|
|
+ page_size = int(page_size)
|
|
|
+ except Exception as e:
|
|
|
+ return wrap_response(404, msg="Invalid parameter: {}".format(e))
|
|
|
+ response = app.dataset_service.get_conversation_data_list(int(dataset_id), page_num, page_size)
|
|
|
+ return wrap_response(200, data=response)
|
|
|
+
|
|
|
+
|
|
|
+@app.route("/api/createAgentTask", methods=["POST"])
|
|
|
+def create_agent_task():
|
|
|
+ """
|
|
|
+ 创建agent执行任务
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ req_data = request.json
|
|
|
+ agent_id = req_data.get('agentId', None)
|
|
|
+ task_prompt = req_data.get('taskPrompt', None)
|
|
|
+ if not agent_id:
|
|
|
+ return wrap_response(404, msg='agent id is required')
|
|
|
+ if not task_prompt:
|
|
|
+ return wrap_response(404, msg='task_prompt is required')
|
|
|
+ app.agent_task_manager.create_task(agent_id, task_prompt)
|
|
|
+ return wrap_response(200)
|
|
|
+
|
|
|
+@app.route("/api/getAgentTaskList", methods=["GET"])
|
|
|
+def get_agent_task_list():
|
|
|
+ """
|
|
|
+ 获取单元测试任务列表
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ page_num = request.args.get("pageNum", const.DEFAULT_PAGE_ID)
|
|
|
+ page_size = request.args.get("pageSize", const.DEFAULT_PAGE_SIZE)
|
|
|
+ try:
|
|
|
+ page_num = int(page_num)
|
|
|
+ page_size = int(page_size)
|
|
|
+ except Exception as e:
|
|
|
+ return wrap_response(404, msg="Invalid parameter: {}".format(e))
|
|
|
+ response = app.agent_task_manager.get_agent_task_list(page_num, page_size)
|
|
|
+ return wrap_response(200, data=response)
|
|
|
+
|
|
|
+
|
|
|
+@app.route("/api/getAgentTaskDetail", methods=["GET"])
|
|
|
+def get_agent_task_detail():
|
|
|
+ """
|
|
|
+ 查询agent执行任务详情
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ agent_task_id = request.args.get("agentTaskId", None)
|
|
|
+ if not agent_task_id:
|
|
|
+ return wrap_response(404, msg='agent_task_id is required')
|
|
|
+ parent_execution_id = request.args.get("parentExecutionId", None)
|
|
|
+ response = app.agent_task_manager.get_agent_task_detail(int(agent_task_id), parent_execution_id)
|
|
|
+ return wrap_response(200, data=response)
|
|
|
+
|
|
|
@app.errorhandler(werkzeug.exceptions.BadRequest)
|
|
|
def handle_bad_request(e):
|
|
|
logger.error(e)
|
|
@@ -536,7 +866,7 @@ if __name__ == '__main__':
|
|
|
|
|
|
config = configs.get()
|
|
|
logging_level = logging.getLevelName(args.log_level)
|
|
|
- logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
|
|
|
+ setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
|
|
|
|
|
|
# set db config
|
|
|
agent_db_config = config['database']['ai_agent']
|
|
@@ -547,7 +877,7 @@ if __name__ == '__main__':
|
|
|
chat_history_db_config = config['storage']['chat_history']
|
|
|
|
|
|
# init user manager
|
|
|
- user_manager = MySQLUserManager(agent_db_config, growth_db_config, staff_db_config['table'])
|
|
|
+ user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
|
|
|
app.user_manager = user_manager
|
|
|
|
|
|
# init session manager
|
|
@@ -559,9 +889,20 @@ if __name__ == '__main__':
|
|
|
chat_history_table=chat_history_db_config['table']
|
|
|
)
|
|
|
app.session_manager = session_manager
|
|
|
- agent_db_engine = create_ai_agent_db_engine(config['database']['ai_agent'])
|
|
|
+ agent_db_engine = create_ai_agent_db_engine()
|
|
|
app.session_maker = sessionmaker(bind=agent_db_engine)
|
|
|
|
|
|
+ dataset_service = DatasetService(session_maker=sessionmaker(bind=agent_db_engine))
|
|
|
+ app.dataset_service = dataset_service
|
|
|
+
|
|
|
+ task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine), dataset_service=dataset_service)
|
|
|
+ app.task_manager = task_manager
|
|
|
+ app.task_manager.recover_tasks()
|
|
|
+
|
|
|
+ agent_task_manager = AgentTaskManager(session_maker=sessionmaker(bind=agent_db_engine))
|
|
|
+ app.agent_task_manager = agent_task_manager
|
|
|
+ app.agent_task_manager.recover_tasks()
|
|
|
+
|
|
|
wecom_db_config = config['storage']['user_relation']
|
|
|
user_relation_manager = MySQLUserRelationManager(
|
|
|
agent_db_config, growth_db_config,
|