|
@@ -7,6 +7,7 @@ import werkzeug.exceptions
|
|
|
from flask import Flask, request, jsonify
|
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
+from pyarrow.dataset import dataset
|
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
|
|
from pqai_agent import configs
|
|
@@ -21,6 +22,7 @@ from pqai_agent.utils.db_utils import create_sql_engine
|
|
|
from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
|
|
|
from pqai_agent_server.const import AgentApiConst
|
|
|
from pqai_agent_server.const.status_enum import TestTaskStatus
|
|
|
+from pqai_agent_server.dataset_server import DatasetServer
|
|
|
from pqai_agent_server.models import MySQLSessionManager
|
|
|
from pqai_agent_server.task_server import TaskManager
|
|
|
from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
|
|
@@ -34,6 +36,7 @@ app = Flask('agent_api_server')
|
|
|
logger = logging_service.logger
|
|
|
const = AgentApiConst()
|
|
|
|
|
|
+
|
|
|
@app.route('/api/listStaffs', methods=['GET'])
|
|
|
def list_staffs():
|
|
|
staff_data = app.user_relation_manager.list_staffs()
|
|
@@ -180,6 +183,7 @@ def run_prompt():
|
|
|
logger.error(e)
|
|
|
return wrap_response(500, msg='Error: {}'.format(e))
|
|
|
|
|
|
+
|
|
|
@app.route('/api/formatForPrompt', methods=['POST'])
|
|
|
def format_data_for_prompt():
|
|
|
try:
|
|
@@ -314,6 +318,7 @@ def quit_human_interventions_status():
|
|
|
|
|
|
return wrap_response(200, data=response)
|
|
|
|
|
|
+
|
|
|
## Agent管理接口
|
|
|
@app.route("/api/getNativeAgentList", methods=["GET"])
|
|
|
def get_native_agent_list():
|
|
@@ -350,6 +355,7 @@ def get_native_agent_list():
|
|
|
]
|
|
|
return wrap_response(200, data=ret_data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/getNativeAgentConfiguration", methods=["GET"])
|
|
|
def get_native_agent_configuration():
|
|
|
"""
|
|
@@ -381,6 +387,7 @@ def get_native_agent_configuration():
|
|
|
}
|
|
|
return wrap_response(200, data=data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
|
|
|
def save_native_agent_configuration():
|
|
|
"""
|
|
@@ -434,6 +441,7 @@ def save_native_agent_configuration():
|
|
|
session.commit()
|
|
|
return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
|
|
|
|
|
|
+
|
|
|
@app.route("/api/getModuleList", methods=["GET"])
|
|
|
def get_module_list():
|
|
|
"""
|
|
@@ -458,6 +466,7 @@ def get_module_list():
|
|
|
]
|
|
|
return wrap_response(200, data=ret_data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/getModuleConfiguration", methods=["GET"])
|
|
|
def get_module_configuration():
|
|
|
"""
|
|
@@ -484,6 +493,7 @@ def get_module_configuration():
|
|
|
}
|
|
|
return wrap_response(200, data=data)
|
|
|
|
|
|
+
|
|
|
@app.route("/api/saveModuleConfiguration", methods=["POST"])
|
|
|
def save_module_configuration():
|
|
|
"""
|
|
@@ -568,12 +578,12 @@ def create_test_task():
|
|
|
"""
|
|
|
req_data = request.json
|
|
|
agent_id = req_data.get('agentId', None)
|
|
|
- model_id = req_data.get('modelId', None)
|
|
|
+ module_id = req_data.get('moduleId', None)
|
|
|
if not agent_id:
|
|
|
return wrap_response(404, msg='agent id is required')
|
|
|
- if not model_id:
|
|
|
- return wrap_response(404, msg='model id is required')
|
|
|
- app.task_manager.create_task(agent_id, model_id)
|
|
|
+ if not module_id:
|
|
|
+ return wrap_response(404, msg='module id is required')
|
|
|
+ app.task_manager.create_task(agent_id, module_id)
|
|
|
return wrap_response(200)
|
|
|
|
|
|
|
|
@@ -651,7 +661,8 @@ if __name__ == '__main__':
|
|
|
agent_db_engine = create_sql_engine(config['storage']['agent_state']['mysql'])
|
|
|
app.session_maker = sessionmaker(bind=agent_db_engine)
|
|
|
|
|
|
- task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine))
|
|
|
+ dataset_server = DatasetServer(session_maker=sessionmaker(bind=agent_db_engine))
|
|
|
+ task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine), dataset_server=dataset_server)
|
|
|
app.task_manager = task_manager
|
|
|
|
|
|
wecom_db_config = config['storage']['user_relation']
|