|
@@ -8,6 +8,7 @@ import werkzeug.exceptions
|
|
|
from flask import Flask, request, jsonify
|
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
+from pyarrow.dataset import dataset
|
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
|
|
from pqai_agent import configs
|
|
@@ -24,6 +25,7 @@ from pqai_agent.utils.db_utils import create_ai_agent_db_engine
|
|
|
from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
|
|
|
from pqai_agent_server.const import AgentApiConst
|
|
|
from pqai_agent_server.const.status_enum import TestTaskStatus
|
|
|
+from pqai_agent_server.dataset_server import DatasetServer
|
|
|
from pqai_agent_server.models import MySQLSessionManager
|
|
|
import pqai_agent_server.utils
|
|
|
from pqai_agent_server.task_server import TaskManager
|
|
@@ -37,6 +39,7 @@ from pqai_agent_server.utils import (
|
|
|
app = Flask('agent_api_server')
|
|
|
const = AgentApiConst()
|
|
|
|
|
|
+
|
|
|
@app.route('/api/listStaffs', methods=['GET'])
|
|
|
def list_staffs():
|
|
|
staff_data = app.user_relation_manager.list_staffs()
|
|
@@ -173,6 +176,7 @@ def run_prompt():
|
|
|
logger.error(e)
|
|
|
return wrap_response(500, msg='Error: {}'.format(e))
|
|
|
|
|
|
+
|
|
|
@app.route('/api/formatForPrompt', methods=['POST'])
|
|
|
def format_data_for_prompt():
|
|
|
try:
|
|
@@ -325,6 +329,7 @@ def enter_human_intervention():
|
|
|
else:
|
|
|
return wrap_response(500, msg="error")
|
|
|
|
|
|
+
|
|
|
## Agent管理接口
|
|
|
@app.route("/api/getNativeAgentList", methods=["GET"])
|
|
|
def get_native_agent_list():
|
|
@@ -367,6 +372,7 @@ def get_native_agent_list():
|
|
|
}
|
|
|
return wrap_response(200, data=ret_data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/getNativeAgentConfiguration", methods=["GET"])
|
|
|
def get_native_agent_configuration():
|
|
|
"""
|
|
@@ -398,6 +404,7 @@ def get_native_agent_configuration():
|
|
|
}
|
|
|
return wrap_response(200, data=data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
|
|
|
def save_native_agent_configuration():
|
|
|
"""
|
|
@@ -536,6 +543,7 @@ def get_module_list():
|
|
|
}
|
|
|
return wrap_response(200, data=ret_data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/getModuleConfiguration", methods=["GET"])
|
|
|
def get_module_configuration():
|
|
|
"""
|
|
@@ -562,6 +570,7 @@ def get_module_configuration():
|
|
|
}
|
|
|
return wrap_response(200, data=data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/saveModuleConfiguration", methods=["POST"])
|
|
|
def save_module_configuration():
|
|
|
"""
|
|
@@ -672,12 +681,12 @@ def create_test_task():
|
|
|
"""
|
|
|
req_data = request.json
|
|
|
agent_id = req_data.get('agentId', None)
|
|
|
- model_id = req_data.get('modelId', None)
|
|
|
+ module_id = req_data.get('moduleId', None)
|
|
|
if not agent_id:
|
|
|
return wrap_response(404, msg='agent id is required')
|
|
|
- if not model_id:
|
|
|
- return wrap_response(404, msg='model id is required')
|
|
|
- app.task_manager.create_task(agent_id, model_id)
|
|
|
+ if not module_id:
|
|
|
+ return wrap_response(404, msg='module id is required')
|
|
|
+ app.task_manager.create_task(agent_id, module_id)
|
|
|
return wrap_response(200)
|
|
|
|
|
|
|
|
@@ -757,7 +766,8 @@ if __name__ == '__main__':
|
|
|
agent_db_engine = create_ai_agent_db_engine()
|
|
|
app.session_maker = sessionmaker(bind=agent_db_engine)
|
|
|
|
|
|
- task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine))
|
|
|
+ dataset_server = DatasetServer(session_maker=sessionmaker(bind=agent_db_engine))
|
|
|
+ task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine), dataset_server=dataset_server)
|
|
|
app.task_manager = task_manager
|
|
|
|
|
|
wecom_db_config = config['storage']['user_relation']
|