|
@@ -7,12 +7,17 @@ import werkzeug.exceptions
|
|
|
from flask import Flask, request, jsonify
|
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
+from sqlalchemy.orm import sessionmaker
|
|
|
+
|
|
|
from pqai_agent import configs
|
|
|
|
|
|
from pqai_agent import logging_service, chat_service, prompt_templates
|
|
|
from pqai_agent.agents.message_reply_agent import MessageReplyAgent
|
|
|
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
|
|
|
+from pqai_agent.data_models.service_module import ServiceModule
|
|
|
from pqai_agent.history_dialogue_service import HistoryDialogueService
|
|
|
from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
|
|
|
+from pqai_agent.utils.db_utils import create_sql_engine
|
|
|
from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
|
|
|
from pqai_agent_server.const import AgentApiConst
|
|
|
from pqai_agent_server.models import MySQLSessionManager
|
|
@@ -307,6 +312,213 @@ def quit_human_interventions_status():
|
|
|
|
|
|
return wrap_response(200, data=response)
|
|
|
|
|
|
+## Agent管理接口
|
|
|
+@app.route("/api/getNativeAgentList", methods=["GET"])
|
|
|
+def get_native_agent_list():
|
|
|
+ """
|
|
|
+ 获取所有的Agent列表
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ page = request.args.get('page', 1)
|
|
|
+ page_size = request.args.get('page_size', 20)
|
|
|
+ create_user = request.args.get('create_user', None)
|
|
|
+ update_user = request.args.get('update_user', None)
|
|
|
+
|
|
|
+ offset = (int(page) - 1) * int(page_size)
|
|
|
+ with app.session_maker() as session:
|
|
|
+ query = session.query(AgentConfiguration) \
|
|
|
+ .filter(AgentConfiguration.is_delete == 0) \
|
|
|
+ .offset(offset).limit(int(page_size))
|
|
|
+ if create_user:
|
|
|
+ query = query.filter(AgentConfiguration.create_user == create_user)
|
|
|
+ if update_user:
|
|
|
+ query = query.filter(AgentConfiguration.update_user == update_user)
|
|
|
+ data = query.all()
|
|
|
+ ret_data = [
|
|
|
+ {
|
|
|
+ 'id': agent.id,
|
|
|
+ 'name': agent.name,
|
|
|
+ 'display_name': agent.display_name,
|
|
|
+ 'type': agent.type,
|
|
|
+ 'execution_model': agent.execution_model,
|
|
|
+ 'created_at': agent.created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
+ 'updated_at': agent.updated_at.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
+ }
|
|
|
+ for agent in data
|
|
|
+ ]
|
|
|
+ return wrap_response(200, data=ret_data)
|
|
|
+
|
|
|
+@app.route("/api/getAgentConfiguration", methods=["GET"])
|
|
|
+def get_agent_configuration():
|
|
|
+ """
|
|
|
+ 获取指定Agent的配置
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ agent_id = request.args.get('agent_id')
|
|
|
+ if not agent_id:
|
|
|
+ return wrap_response(404, msg='agent_id is required')
|
|
|
+
|
|
|
+ with app.session_maker() as session:
|
|
|
+ agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
|
|
|
+ if not agent:
|
|
|
+ return wrap_response(404, msg='Agent not found')
|
|
|
+
|
|
|
+ data = {
|
|
|
+ 'id': agent.id,
|
|
|
+ 'name': agent.name,
|
|
|
+ 'display_name': agent.display_name,
|
|
|
+ 'type': agent.type,
|
|
|
+ 'execution_model': agent.execution_model,
|
|
|
+ 'system_prompt': agent.system_prompt,
|
|
|
+ 'task_prompt': agent.task_prompt,
|
|
|
+ 'tools': agent.tools,
|
|
|
+ 'sub_agents': agent.sub_agents,
|
|
|
+ 'extra_params': agent.extra_params,
|
|
|
+ 'created_at': agent.created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
+ 'updated_at': agent.updated_at.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
+ }
|
|
|
+ return wrap_response(200, data=data)
|
|
|
+
|
|
|
+@app.route("/api/saveAgentConfiguration", methods=["POST"])
|
|
|
+def save_agent_configuration():
|
|
|
+ """
|
|
|
+ 保存Agent配置
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ req_data = request.json
|
|
|
+ agent_id = req_data.get('agent_id', None)
|
|
|
+ name = req_data.get('name')
|
|
|
+ display_name = req_data.get('display_name', None)
|
|
|
+ type_ = req_data.get('type', 0)
|
|
|
+ execution_model = req_data.get('execution_model', None)
|
|
|
+ system_prompt = req_data.get('system_prompt', None)
|
|
|
+ task_prompt = req_data.get('task_prompt', None)
|
|
|
+ tools = req_data.get('tools', [])
|
|
|
+ sub_agents = req_data.get('sub_agents', [])
|
|
|
+ extra_params = req_data.get('extra_params', {})
|
|
|
+
|
|
|
+ if not name:
|
|
|
+ return wrap_response(400, msg='name is required')
|
|
|
+
|
|
|
+ with app.session_maker() as session:
|
|
|
+ if agent_id:
|
|
|
+ agent_id = int(agent_id)
|
|
|
+ agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
|
|
|
+ if not agent:
|
|
|
+ return wrap_response(404, msg='Agent not found')
|
|
|
+ agent.name = name
|
|
|
+ agent.display_name = display_name
|
|
|
+ agent.type = type_
|
|
|
+ agent.execution_model = execution_model
|
|
|
+ agent.system_prompt = system_prompt
|
|
|
+ agent.task_prompt = task_prompt
|
|
|
+ agent.tools = tools
|
|
|
+ agent.sub_agents = sub_agents
|
|
|
+ agent.extra_params = extra_params
|
|
|
+ else:
|
|
|
+ agent = AgentConfiguration(
|
|
|
+ name=name,
|
|
|
+ display_name=display_name,
|
|
|
+ type=type_,
|
|
|
+ execution_model=execution_model,
|
|
|
+ system_prompt=system_prompt,
|
|
|
+ task_prompt=task_prompt,
|
|
|
+ tools=tools,
|
|
|
+ sub_agents=sub_agents,
|
|
|
+ extra_params=extra_params
|
|
|
+ )
|
|
|
+ session.add(agent)
|
|
|
+
|
|
|
+ session.commit()
|
|
|
+ return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
|
|
|
+
|
|
|
+@app.route("/api/getModuleList", methods=["GET"])
|
|
|
+def get_module_list():
|
|
|
+ """
|
|
|
+ 获取所有的模块列表
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ with app.session_maker() as session:
|
|
|
+ query = session.query(ServiceModule) \
|
|
|
+ .filter(ServiceModule.is_delete == 0)
|
|
|
+ data = query.all()
|
|
|
+ ret_data = [
|
|
|
+ {
|
|
|
+ 'id': module.id,
|
|
|
+ 'name': module.name,
|
|
|
+ 'display_name': module.display_name,
|
|
|
+ 'default_agent_type': module.default_agent_type,
|
|
|
+ 'default_agent_id': module.default_agent_id,
|
|
|
+ 'created_at': module.created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
+ 'updated_at': module.updated_at.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
+ }
|
|
|
+ for module in data
|
|
|
+ ]
|
|
|
+ return wrap_response(200, data=ret_data)
|
|
|
+
|
|
|
+@app.route("/api/getModuleConfiguration", methods=["GET"])
|
|
|
+def get_module_configuration():
|
|
|
+ """
|
|
|
+ 获取指定模块的配置
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ module_id = request.args.get('module_id')
|
|
|
+ if not module_id:
|
|
|
+ return wrap_response(404, msg='module_id is required')
|
|
|
+
|
|
|
+ with app.session_maker() as session:
|
|
|
+ module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
|
|
|
+ if not module:
|
|
|
+ return wrap_response(404, msg='Module not found')
|
|
|
+
|
|
|
+ data = {
|
|
|
+ 'id': module.id,
|
|
|
+ 'name': module.name,
|
|
|
+ 'display_name': module.display_name,
|
|
|
+ 'default_agent_type': module.default_agent_type,
|
|
|
+ 'default_agent_id': module.default_agent_id,
|
|
|
+ 'created_at': module.created_at.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
+ 'updated_at': module.updated_at.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
+ }
|
|
|
+ return wrap_response(200, data=data)
|
|
|
+
|
|
|
+@app.route("/api/saveModuleConfiguration", methods=["POST"])
|
|
|
+def save_module_configuration():
|
|
|
+ """
|
|
|
+ 保存模块配置
|
|
|
+ :return:
|
|
|
+ """
|
|
|
+ req_data = request.json
|
|
|
+ module_id = req_data.get('module_id', None)
|
|
|
+ name = req_data.get('name')
|
|
|
+ display_name = req_data.get('display_name', None)
|
|
|
+ default_agent_type = req_data.get('default_agent_type', 0)
|
|
|
+ default_agent_id = req_data.get('default_agent_id', None)
|
|
|
+
|
|
|
+ if not name:
|
|
|
+ return wrap_response(400, msg='name is required')
|
|
|
+
|
|
|
+ with app.session_maker() as session:
|
|
|
+ if module_id:
|
|
|
+ module_id = int(module_id)
|
|
|
+ module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
|
|
|
+ if not module:
|
|
|
+ return wrap_response(404, msg='Module not found')
|
|
|
+ module.name = name
|
|
|
+ module.display_name = display_name
|
|
|
+ module.default_agent_type = default_agent_type
|
|
|
+ module.default_agent_id = default_agent_id
|
|
|
+ else:
|
|
|
+ module = ServiceModule(
|
|
|
+ name=name,
|
|
|
+ display_name=display_name,
|
|
|
+ default_agent_type=default_agent_type,
|
|
|
+ default_agent_id=default_agent_id
|
|
|
+ )
|
|
|
+ session.add(module)
|
|
|
+
|
|
|
+ session.commit()
|
|
|
+ return wrap_response(200, msg='Module configuration saved successfully', data={'id': module.id})
|
|
|
|
|
|
@app.errorhandler(werkzeug.exceptions.BadRequest)
|
|
|
def handle_bad_request(e):
|
|
@@ -345,6 +557,8 @@ if __name__ == '__main__':
|
|
|
chat_history_table=chat_history_db_config['table']
|
|
|
)
|
|
|
app.session_manager = session_manager
|
|
|
+ agent_db_engine = create_sql_engine(config['storage']['agent_state']['mysql'])
|
|
|
+ app.session_maker = sessionmaker(bind=agent_db_engine)
|
|
|
|
|
|
wecom_db_config = config['storage']['user_relation']
|
|
|
user_relation_manager = MySQLUserRelationManager(
|