Browse Source

Update api_server: add agent and module APIs

StrayWarrior 4 days ago
parent
commit
c3d952b310
1 changed files with 214 additions and 0 deletions
  1. 214 0
      pqai_agent_server/api_server.py

+ 214 - 0
pqai_agent_server/api_server.py

@@ -7,12 +7,17 @@ import werkzeug.exceptions
 from flask import Flask, request, jsonify
 from argparse import ArgumentParser
 
+from sqlalchemy.orm import sessionmaker
+
 from pqai_agent import configs
 
 from pqai_agent import logging_service, chat_service, prompt_templates
 from pqai_agent.agents.message_reply_agent import MessageReplyAgent
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
+from pqai_agent.data_models.service_module import ServiceModule
 from pqai_agent.history_dialogue_service import HistoryDialogueService
 from pqai_agent.user_manager import MySQLUserManager, MySQLUserRelationManager
+from pqai_agent.utils.db_utils import create_sql_engine
 from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
 from pqai_agent_server.const import AgentApiConst
 from pqai_agent_server.models import MySQLSessionManager
@@ -307,6 +312,213 @@ def quit_human_interventions_status():
 
     return wrap_response(200, data=response)
 
+## Agent管理接口
+@app.route("/api/getNativeAgentList", methods=["GET"])
+def get_native_agent_list():
+    """
+    获取所有的Agent列表
+    :return:
+    """
+    page = request.args.get('page', 1)
+    page_size = request.args.get('page_size', 20)
+    create_user = request.args.get('create_user', None)
+    update_user = request.args.get('update_user', None)
+
+    offset = (int(page) - 1) * int(page_size)
+    with app.session_maker() as session:
+        query = session.query(AgentConfiguration) \
+            .filter(AgentConfiguration.is_delete == 0) \
+            .offset(offset).limit(int(page_size))
+        if create_user:
+            query = query.filter(AgentConfiguration.create_user == create_user)
+        if update_user:
+            query = query.filter(AgentConfiguration.update_user == update_user)
+        data = query.all()
+    ret_data = [
+        {
+            'id': agent.id,
+            'name': agent.name,
+            'display_name': agent.display_name,
+            'type': agent.type,
+            'execution_model': agent.execution_model,
+            'created_at': agent.created_at.strftime('%Y-%m-%d %H:%M:%S'),
+            'updated_at': agent.updated_at.strftime('%Y-%m-%d %H:%M:%S')
+        }
+        for agent in data
+    ]
+    return wrap_response(200, data=ret_data)
+
+@app.route("/api/getAgentConfiguration", methods=["GET"])
+def get_agent_configuration():
+    """
+    获取指定Agent的配置
+    :return:
+    """
+    agent_id = request.args.get('agent_id')
+    if not agent_id:
+        return wrap_response(404, msg='agent_id is required')
+
+    with app.session_maker() as session:
+        agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
+        if not agent:
+            return wrap_response(404, msg='Agent not found')
+
+        data = {
+            'id': agent.id,
+            'name': agent.name,
+            'display_name': agent.display_name,
+            'type': agent.type,
+            'execution_model': agent.execution_model,
+            'system_prompt': agent.system_prompt,
+            'task_prompt': agent.task_prompt,
+            'tools': agent.tools,
+            'sub_agents': agent.sub_agents,
+            'extra_params': agent.extra_params,
+            'created_at': agent.created_at.strftime('%Y-%m-%d %H:%M:%S'),
+            'updated_at': agent.updated_at.strftime('%Y-%m-%d %H:%M:%S')
+        }
+        return wrap_response(200, data=data)
+
+@app.route("/api/saveAgentConfiguration", methods=["POST"])
+def save_agent_configuration():
+    """
+    保存Agent配置
+    :return:
+    """
+    req_data = request.json
+    agent_id = req_data.get('agent_id', None)
+    name = req_data.get('name')
+    display_name = req_data.get('display_name', None)
+    type_ = req_data.get('type', 0)
+    execution_model = req_data.get('execution_model', None)
+    system_prompt = req_data.get('system_prompt', None)
+    task_prompt = req_data.get('task_prompt', None)
+    tools = req_data.get('tools', [])
+    sub_agents = req_data.get('sub_agents', [])
+    extra_params = req_data.get('extra_params', {})
+
+    if not name:
+        return wrap_response(400, msg='name is required')
+
+    with app.session_maker() as session:
+        if agent_id:
+            agent_id = int(agent_id)
+            agent = session.query(AgentConfiguration).filter(AgentConfiguration.id == agent_id).first()
+            if not agent:
+                return wrap_response(404, msg='Agent not found')
+            agent.name = name
+            agent.display_name = display_name
+            agent.type = type_
+            agent.execution_model = execution_model
+            agent.system_prompt = system_prompt
+            agent.task_prompt = task_prompt
+            agent.tools = tools
+            agent.sub_agents = sub_agents
+            agent.extra_params = extra_params
+        else:
+            agent = AgentConfiguration(
+                name=name,
+                display_name=display_name,
+                type=type_,
+                execution_model=execution_model,
+                system_prompt=system_prompt,
+                task_prompt=task_prompt,
+                tools=tools,
+                sub_agents=sub_agents,
+                extra_params=extra_params
+            )
+            session.add(agent)
+
+        session.commit()
+        return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
+
+@app.route("/api/getModuleList", methods=["GET"])
+def get_module_list():
+    """
+    获取所有的模块列表
+    :return:
+    """
+    with app.session_maker() as session:
+        query = session.query(ServiceModule) \
+            .filter(ServiceModule.is_delete == 0)
+        data = query.all()
+    ret_data = [
+        {
+            'id': module.id,
+            'name': module.name,
+            'display_name': module.display_name,
+            'default_agent_type': module.default_agent_type,
+            'default_agent_id': module.default_agent_id,
+            'created_at': module.created_at.strftime('%Y-%m-%d %H:%M:%S'),
+            'updated_at': module.updated_at.strftime('%Y-%m-%d %H:%M:%S')
+        }
+        for module in data
+    ]
+    return wrap_response(200, data=ret_data)
+
+@app.route("/api/getModuleConfiguration", methods=["GET"])
+def get_module_configuration():
+    """
+    获取指定模块的配置
+    :return:
+    """
+    module_id = request.args.get('module_id')
+    if not module_id:
+        return wrap_response(404, msg='module_id is required')
+
+    with app.session_maker() as session:
+        module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
+        if not module:
+            return wrap_response(404, msg='Module not found')
+
+        data = {
+            'id': module.id,
+            'name': module.name,
+            'display_name': module.display_name,
+            'default_agent_type': module.default_agent_type,
+            'default_agent_id': module.default_agent_id,
+            'created_at': module.created_at.strftime('%Y-%m-%d %H:%M:%S'),
+            'updated_at': module.updated_at.strftime('%Y-%m-%d %H:%M:%S')
+        }
+        return wrap_response(200, data=data)
+
+@app.route("/api/saveModuleConfiguration", methods=["POST"])
+def save_module_configuration():
+    """
+    保存模块配置
+    :return:
+    """
+    req_data = request.json
+    module_id = req_data.get('module_id', None)
+    name = req_data.get('name')
+    display_name = req_data.get('display_name', None)
+    default_agent_type = req_data.get('default_agent_type', 0)
+    default_agent_id = req_data.get('default_agent_id', None)
+
+    if not name:
+        return wrap_response(400, msg='name is required')
+
+    with app.session_maker() as session:
+        if module_id:
+            module_id = int(module_id)
+            module = session.query(ServiceModule).filter(ServiceModule.id == module_id).first()
+            if not module:
+                return wrap_response(404, msg='Module not found')
+            module.name = name
+            module.display_name = display_name
+            module.default_agent_type = default_agent_type
+            module.default_agent_id = default_agent_id
+        else:
+            module = ServiceModule(
+                name=name,
+                display_name=display_name,
+                default_agent_type=default_agent_type,
+                default_agent_id=default_agent_id
+            )
+            session.add(module)
+
+        session.commit()
+        return wrap_response(200, msg='Module configuration saved successfully', data={'id': module.id})
 
 @app.errorhandler(werkzeug.exceptions.BadRequest)
 def handle_bad_request(e):
@@ -345,6 +557,8 @@ if __name__ == '__main__':
         chat_history_table=chat_history_db_config['table']
     )
     app.session_manager = session_manager
+    agent_db_engine = create_sql_engine(config['storage']['agent_state']['mysql'])
+    app.session_maker = sessionmaker(bind=agent_db_engine)
 
     wecom_db_config = config['storage']['user_relation']
     user_relation_manager = MySQLUserRelationManager(