|
@@ -1,6 +1,8 @@
|
|
#! /usr/bin/env python
|
|
#! /usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
# -*- coding: utf-8 -*-
|
|
# vim:fenc=utf-8
|
|
# vim:fenc=utf-8
|
|
|
|
+import json
|
|
|
|
+import time
|
|
import logging
|
|
import logging
|
|
from argparse import ArgumentParser
|
|
from argparse import ArgumentParser
|
|
|
|
|
|
@@ -9,6 +11,10 @@ from flask import Flask, request, jsonify
|
|
from sqlalchemy.orm import sessionmaker
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
|
|
from pqai_agent import configs
|
|
from pqai_agent import configs
|
|
|
|
+
|
|
|
|
+from pqai_agent import chat_service, prompt_templates
|
|
|
|
+from pqai_agent.logging import logger, setup_root_logger
|
|
|
|
+from pqai_agent.toolkit import global_tool_map
|
|
from pqai_agent import logging_service, chat_service, prompt_templates
|
|
from pqai_agent import logging_service, chat_service, prompt_templates
|
|
from pqai_agent.data_models.agent_configuration import AgentConfiguration
|
|
from pqai_agent.data_models.agent_configuration import AgentConfiguration
|
|
from pqai_agent.data_models.service_module import ServiceModule
|
|
from pqai_agent.data_models.service_module import ServiceModule
|
|
@@ -21,6 +27,8 @@ from pqai_agent_server.const.status_enum import TestTaskStatus
|
|
from pqai_agent_server.const.type_enum import EvaluateType
|
|
from pqai_agent_server.const.type_enum import EvaluateType
|
|
from pqai_agent_server.dataset_service import DatasetService
|
|
from pqai_agent_server.dataset_service import DatasetService
|
|
from pqai_agent_server.models import MySQLSessionManager
|
|
from pqai_agent_server.models import MySQLSessionManager
|
|
|
|
+import pqai_agent_server.utils
|
|
|
|
+from pqai_agent_server.utils import wrap_response
|
|
from pqai_agent_server.task_server import TaskManager
|
|
from pqai_agent_server.task_server import TaskManager
|
|
from pqai_agent_server.utils import (
|
|
from pqai_agent_server.utils import (
|
|
run_extractor_prompt,
|
|
run_extractor_prompt,
|
|
@@ -30,7 +38,6 @@ from pqai_agent_server.utils import (
|
|
from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
|
|
from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
|
|
|
|
|
|
app = Flask('agent_api_server')
|
|
app = Flask('agent_api_server')
|
|
-logger = logging_service.logger
|
|
|
|
const = AgentApiConst()
|
|
const = AgentApiConst()
|
|
|
|
|
|
|
|
|
|
@@ -93,34 +100,23 @@ def get_dialogue_history():
|
|
|
|
|
|
@app.route('/api/listModels', methods=['GET'])
|
|
@app.route('/api/listModels', methods=['GET'])
|
|
def list_models():
|
|
def list_models():
|
|
- models = [
|
|
|
|
- {
|
|
|
|
- 'model_type': 'openai_compatible',
|
|
|
|
- 'model_name': chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
|
|
|
|
- 'display_name': 'DeepSeek V3 on 火山'
|
|
|
|
- },
|
|
|
|
- {
|
|
|
|
- 'model_type': 'openai_compatible',
|
|
|
|
- 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
|
|
- 'display_name': '豆包Pro 32K'
|
|
|
|
- },
|
|
|
|
- {
|
|
|
|
- 'model_type': 'openai_compatible',
|
|
|
|
- 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
|
|
|
|
- 'display_name': '豆包Pro 1.5'
|
|
|
|
- },
|
|
|
|
- {
|
|
|
|
- 'model_type': 'openai_compatible',
|
|
|
|
- 'model_name': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH,
|
|
|
|
- 'display_name': 'DeepSeek V3联网 on 火山'
|
|
|
|
- },
|
|
|
|
|
|
+ models = {
|
|
|
|
+ "deepseek-chat": chat_service.VOLCENGINE_MODEL_DEEPSEEK_V3,
|
|
|
|
+ "gpt-4o": chat_service.OPENAI_MODEL_GPT_4o,
|
|
|
|
+ "gpt-4o-mini": chat_service.OPENAI_MODEL_GPT_4o_mini,
|
|
|
|
+ "doubao-pro-32k": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
|
|
+ "doubao-pro-1.5": chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
|
|
|
|
+ "doubao-1.5-vision-pro": chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
|
|
|
|
+ }
|
|
|
|
+ ret_data = [
|
|
{
|
|
{
|
|
'model_type': 'openai_compatible',
|
|
'model_type': 'openai_compatible',
|
|
- 'model_name': chat_service.VOLCENGINE_MODEL_DOUBAO_1_5_VISION_PRO,
|
|
|
|
- 'display_name': '豆包1.5视觉理解Pro'
|
|
|
|
- },
|
|
|
|
|
|
+ 'model_name': model_name,
|
|
|
|
+ 'display_name': model_display_name
|
|
|
|
+ }
|
|
|
|
+ for model_display_name, model_name in models.items()
|
|
]
|
|
]
|
|
- return wrap_response(200, data=models)
|
|
|
|
|
|
+ return wrap_response(200, data=ret_data)
|
|
|
|
|
|
|
|
|
|
@app.route('/api/listScenes', methods=['GET'])
|
|
@app.route('/api/listScenes', methods=['GET'])
|
|
@@ -148,8 +144,8 @@ def get_base_prompt():
|
|
model_map = {
|
|
model_map = {
|
|
'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
'greeting': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
'chitchat': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_32K,
|
|
- 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
|
|
|
|
- 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5,
|
|
|
|
|
|
+ 'profile_extractor': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
|
|
|
|
+ 'response_type_detector': chat_service.VOLCENGINE_MODEL_DOUBAO_PRO_1_5_32K,
|
|
'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
|
|
'custom_debugging': chat_service.VOLCENGINE_BOT_DEEPSEEK_V3_SEARCH
|
|
}
|
|
}
|
|
if scene not in prompt_map:
|
|
if scene not in prompt_map:
|
|
@@ -180,7 +176,6 @@ def run_prompt():
|
|
logger.error(e)
|
|
logger.error(e)
|
|
return wrap_response(500, msg='Error: {}'.format(e))
|
|
return wrap_response(500, msg='Error: {}'.format(e))
|
|
|
|
|
|
-
|
|
|
|
@app.route('/api/formatForPrompt', methods=['POST'])
|
|
@app.route('/api/formatForPrompt', methods=['POST'])
|
|
def format_data_for_prompt():
|
|
def format_data_for_prompt():
|
|
try:
|
|
try:
|
|
@@ -300,8 +295,8 @@ def send_message():
|
|
return wrap_response(200, msg="暂不实现功能")
|
|
return wrap_response(200, msg="暂不实现功能")
|
|
|
|
|
|
|
|
|
|
-@app.route("/api/quitHumanInterventionStatus", methods=["POST"])
|
|
|
|
-def quit_human_interventions_status():
|
|
|
|
|
|
+@app.route("/api/quitHumanIntervention", methods=["POST"])
|
|
|
|
+def quit_human_intervention():
|
|
"""
|
|
"""
|
|
退出人工介入状态
|
|
退出人工介入状态
|
|
:return:
|
|
:return:
|
|
@@ -311,10 +306,27 @@ def quit_human_interventions_status():
|
|
user_id = req_data["user_id"]
|
|
user_id = req_data["user_id"]
|
|
if not user_id or not staff_id:
|
|
if not user_id or not staff_id:
|
|
return wrap_response(404, msg="user_id and staff_id are required")
|
|
return wrap_response(404, msg="user_id and staff_id are required")
|
|
- response = quit_human_intervention_status(user_id, staff_id)
|
|
|
|
|
|
+ if pqai_agent_server.utils.common.quit_human_intervention(user_id, staff_id):
|
|
|
|
+ return wrap_response(200, msg="success")
|
|
|
|
+ else:
|
|
|
|
+ return wrap_response(500, msg="error")
|
|
|
|
|
|
- return wrap_response(200, data=response)
|
|
|
|
|
|
|
|
|
|
+@app.route("/api/enterHumanIntervention", methods=["POST"])
|
|
|
|
+def enter_human_intervention():
|
|
|
|
+ """
|
|
|
|
+ 进入人工介入状态
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ req_data = request.json
|
|
|
|
+ staff_id = req_data["staff_id"]
|
|
|
|
+ user_id = req_data["user_id"]
|
|
|
|
+ if not user_id or not staff_id:
|
|
|
|
+ return wrap_response(404, msg="user_id and staff_id are required")
|
|
|
|
+ if pqai_agent_server.utils.common.enter_human_intervention(user_id, staff_id):
|
|
|
|
+ return wrap_response(200, msg="success")
|
|
|
|
+ else:
|
|
|
|
+ return wrap_response(500, msg="error")
|
|
|
|
|
|
## Agent管理接口
|
|
## Agent管理接口
|
|
@app.route("/api/getNativeAgentList", methods=["GET"])
|
|
@app.route("/api/getNativeAgentList", methods=["GET"])
|
|
@@ -336,23 +348,28 @@ def get_native_agent_list():
|
|
query = query.filter(AgentConfiguration.create_user == create_user)
|
|
query = query.filter(AgentConfiguration.create_user == create_user)
|
|
if update_user:
|
|
if update_user:
|
|
query = query.filter(AgentConfiguration.update_user == update_user)
|
|
query = query.filter(AgentConfiguration.update_user == update_user)
|
|
|
|
+ total = query.count()
|
|
query = query.offset(offset).limit(int(page_size))
|
|
query = query.offset(offset).limit(int(page_size))
|
|
data = query.all()
|
|
data = query.all()
|
|
- ret_data = [
|
|
|
|
- {
|
|
|
|
- 'id': agent.id,
|
|
|
|
- 'name': agent.name,
|
|
|
|
- 'display_name': agent.display_name,
|
|
|
|
- 'type': agent.type,
|
|
|
|
- 'execution_model': agent.execution_model,
|
|
|
|
- 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
|
- 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
|
- }
|
|
|
|
- for agent in data
|
|
|
|
- ]
|
|
|
|
|
|
+ ret_data = {
|
|
|
|
+ 'total': total,
|
|
|
|
+ 'agent_list': [
|
|
|
|
+ {
|
|
|
|
+ 'id': agent.id,
|
|
|
|
+ 'name': agent.name,
|
|
|
|
+ 'display_name': agent.display_name,
|
|
|
|
+ 'type': agent.type,
|
|
|
|
+ 'execution_model': agent.execution_model,
|
|
|
|
+ 'create_user': agent.create_user,
|
|
|
|
+ 'update_user': agent.update_user,
|
|
|
|
+ 'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
|
|
|
|
+ 'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
|
|
|
|
+ }
|
|
|
|
+ for agent in data
|
|
|
|
+ ]
|
|
|
|
+ }
|
|
return wrap_response(200, data=ret_data)
|
|
return wrap_response(200, data=ret_data)
|
|
|
|
|
|
-
|
|
|
|
@app.route("/api/getNativeAgentConfiguration", methods=["GET"])
|
|
@app.route("/api/getNativeAgentConfiguration", methods=["GET"])
|
|
def get_native_agent_configuration():
|
|
def get_native_agent_configuration():
|
|
"""
|
|
"""
|
|
@@ -376,15 +393,14 @@ def get_native_agent_configuration():
|
|
'execution_model': agent.execution_model,
|
|
'execution_model': agent.execution_model,
|
|
'system_prompt': agent.system_prompt,
|
|
'system_prompt': agent.system_prompt,
|
|
'task_prompt': agent.task_prompt,
|
|
'task_prompt': agent.task_prompt,
|
|
- 'tools': agent.tools,
|
|
|
|
- 'sub_agents': agent.sub_agents,
|
|
|
|
- 'extra_params': agent.extra_params,
|
|
|
|
|
|
+ 'tools': json.loads(agent.tools),
|
|
|
|
+ 'sub_agents': json.loads(agent.sub_agents),
|
|
|
|
+ 'extra_params': json.loads(agent.extra_params),
|
|
'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
|
|
'create_time': agent.create_time.strftime('%Y-%m-%d %H:%M:%S'),
|
|
'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
|
|
'update_time': agent.update_time.strftime('%Y-%m-%d %H:%M:%S')
|
|
}
|
|
}
|
|
return wrap_response(200, data=data)
|
|
return wrap_response(200, data=data)
|
|
|
|
|
|
-
|
|
|
|
@app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
|
|
@app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
|
|
def save_native_agent_configuration():
|
|
def save_native_agent_configuration():
|
|
"""
|
|
"""
|
|
@@ -399,9 +415,19 @@ def save_native_agent_configuration():
|
|
execution_model = req_data.get('execution_model', None)
|
|
execution_model = req_data.get('execution_model', None)
|
|
system_prompt = req_data.get('system_prompt', None)
|
|
system_prompt = req_data.get('system_prompt', None)
|
|
task_prompt = req_data.get('task_prompt', None)
|
|
task_prompt = req_data.get('task_prompt', None)
|
|
- tools = req_data.get('tools', [])
|
|
|
|
- sub_agents = req_data.get('sub_agents', [])
|
|
|
|
|
|
+ tools = json.dumps(req_data.get('tools', []))
|
|
|
|
+ sub_agents = json.dumps(req_data.get('sub_agents', []))
|
|
extra_params = req_data.get('extra_params', {})
|
|
extra_params = req_data.get('extra_params', {})
|
|
|
|
+ operate_user = req_data.get('user', None)
|
|
|
|
+ if isinstance(extra_params, dict):
|
|
|
|
+ extra_params = json.dumps(extra_params)
|
|
|
|
+ elif isinstance(extra_params, str):
|
|
|
|
+ try:
|
|
|
|
+ json.loads(extra_params)
|
|
|
|
+ except json.JSONDecodeError:
|
|
|
|
+ return wrap_response(400, msg='extra_params should be a valid JSON object or string')
|
|
|
|
+ if not extra_params:
|
|
|
|
+ extra_params = '{}'
|
|
|
|
|
|
if not name:
|
|
if not name:
|
|
return wrap_response(400, msg='name is required')
|
|
return wrap_response(400, msg='name is required')
|
|
@@ -421,6 +447,7 @@ def save_native_agent_configuration():
|
|
agent.tools = tools
|
|
agent.tools = tools
|
|
agent.sub_agents = sub_agents
|
|
agent.sub_agents = sub_agents
|
|
agent.extra_params = extra_params
|
|
agent.extra_params = extra_params
|
|
|
|
+ agent.update_user = operate_user
|
|
else:
|
|
else:
|
|
agent = AgentConfiguration(
|
|
agent = AgentConfiguration(
|
|
name=name,
|
|
name=name,
|
|
@@ -431,7 +458,9 @@ def save_native_agent_configuration():
|
|
task_prompt=task_prompt,
|
|
task_prompt=task_prompt,
|
|
tools=tools,
|
|
tools=tools,
|
|
sub_agents=sub_agents,
|
|
sub_agents=sub_agents,
|
|
- extra_params=extra_params
|
|
|
|
|
|
+ extra_params=extra_params,
|
|
|
|
+ create_user=operate_user,
|
|
|
|
+ update_user=operate_user
|
|
)
|
|
)
|
|
session.add(agent)
|
|
session.add(agent)
|
|
|
|
|
|
@@ -439,6 +468,35 @@ def save_native_agent_configuration():
|
|
return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
|
|
return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
|
|
|
|
|
|
|
|
|
|
|
|
+
|
|
|
|
+@app.route("/api/deleteNativeAgentConfiguration", methods=["POST"])
|
|
|
|
+def delete_native_agent_configuration():
|
|
|
|
+ """
|
|
|
|
+ 删除指定Agent配置(软删除,设置is_delete=1)
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ req_data = request.json
|
|
|
|
+ agent_id = req_data.get('agent_id', None)
|
|
|
|
+ if not agent_id:
|
|
|
|
+ return wrap_response(400, msg='agent_id is required')
|
|
|
|
+ try:
|
|
|
|
+ agent_id = int(agent_id)
|
|
|
|
+ except ValueError:
|
|
|
|
+ return wrap_response(400, msg='agent_id must be an integer')
|
|
|
|
+
|
|
|
|
+ with app.session_maker() as session:
|
|
|
|
+ agent = session.query(AgentConfiguration).filter(
|
|
|
|
+ AgentConfiguration.id == agent_id,
|
|
|
|
+ AgentConfiguration.is_delete == 0
|
|
|
|
+ ).first()
|
|
|
|
+ if not agent:
|
|
|
|
+ return wrap_response(404, msg='Agent not found')
|
|
|
|
+ agent.is_delete = 1
|
|
|
|
+ session.commit()
|
|
|
|
+ return wrap_response(200, msg='Agent configuration deleted successfully')
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
@app.route("/api/getModuleList", methods=["GET"])
|
|
@app.route("/api/getModuleList", methods=["GET"])
|
|
def get_module_list():
|
|
def get_module_list():
|
|
"""
|
|
"""
|
|
@@ -463,7 +521,6 @@ def get_module_list():
|
|
]
|
|
]
|
|
return wrap_response(200, data=ret_data)
|
|
return wrap_response(200, data=ret_data)
|
|
|
|
|
|
-
|
|
|
|
@app.route("/api/getModuleConfiguration", methods=["GET"])
|
|
@app.route("/api/getModuleConfiguration", methods=["GET"])
|
|
def get_module_configuration():
|
|
def get_module_configuration():
|
|
"""
|
|
"""
|
|
@@ -490,7 +547,6 @@ def get_module_configuration():
|
|
}
|
|
}
|
|
return wrap_response(200, data=data)
|
|
return wrap_response(200, data=data)
|
|
|
|
|
|
-
|
|
|
|
@app.route("/api/saveModuleConfiguration", methods=["POST"])
|
|
@app.route("/api/saveModuleConfiguration", methods=["POST"])
|
|
def save_module_configuration():
|
|
def save_module_configuration():
|
|
"""
|
|
"""
|
|
@@ -673,6 +729,33 @@ def get_conversation_data_list():
|
|
return wrap_response(200, data=response)
|
|
return wrap_response(200, data=response)
|
|
|
|
|
|
|
|
|
|
|
|
+@app.route("/api/getToolList", methods=["GET"])
|
|
|
|
+def get_tool_list():
|
|
|
|
+ """
|
|
|
|
+ 获取所有的工具列表
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ tools = []
|
|
|
|
+ for tool_name, tool in global_tool_map.items():
|
|
|
|
+ tools.append({
|
|
|
|
+ 'name': tool_name,
|
|
|
|
+ 'description': tool.get_function_description(),
|
|
|
|
+ 'parameters': tool.parameters if hasattr(tool, 'parameters') else {}
|
|
|
|
+ })
|
|
|
|
+ return wrap_response(200, data=tools)
|
|
|
|
+
|
|
|
|
+@app.route("/api/getModuleAgentTypes", methods=["GET"])
|
|
|
|
+def get_agent_types():
|
|
|
|
+ """
|
|
|
|
+ 获取所有的Agent类型
|
|
|
|
+ :return:
|
|
|
|
+ """
|
|
|
|
+ agent_types = [
|
|
|
|
+ {'type': 0, 'display_name': '原生'},
|
|
|
|
+ {'type': 1, 'display_name': 'Coze'}
|
|
|
|
+ ]
|
|
|
|
+ return wrap_response(200, data=agent_types)
|
|
|
|
+
|
|
@app.errorhandler(werkzeug.exceptions.BadRequest)
|
|
@app.errorhandler(werkzeug.exceptions.BadRequest)
|
|
def handle_bad_request(e):
|
|
def handle_bad_request(e):
|
|
logger.error(e)
|
|
logger.error(e)
|
|
@@ -689,7 +772,7 @@ if __name__ == '__main__':
|
|
|
|
|
|
config = configs.get()
|
|
config = configs.get()
|
|
logging_level = logging.getLevelName(args.log_level)
|
|
logging_level = logging.getLevelName(args.log_level)
|
|
- logging_service.setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
|
|
|
|
|
|
+ setup_root_logger(level=logging_level, logfile_name='agent_api_server.log')
|
|
|
|
|
|
# set db config
|
|
# set db config
|
|
agent_db_config = config['database']['ai_agent']
|
|
agent_db_config = config['database']['ai_agent']
|
|
@@ -700,7 +783,7 @@ if __name__ == '__main__':
|
|
chat_history_db_config = config['storage']['chat_history']
|
|
chat_history_db_config = config['storage']['chat_history']
|
|
|
|
|
|
# init user manager
|
|
# init user manager
|
|
- user_manager = MySQLUserManager(agent_db_config, growth_db_config, staff_db_config['table'])
|
|
|
|
|
|
+ user_manager = MySQLUserManager(agent_db_config, user_db_config['table'], staff_db_config['table'])
|
|
app.user_manager = user_manager
|
|
app.user_manager = user_manager
|
|
|
|
|
|
# init session manager
|
|
# init session manager
|