|
@@ -8,12 +8,17 @@ from typing import List, Dict, Optional
|
|
|
import pymysql
|
|
|
from pymysql import Connection
|
|
|
from pymysql.cursors import DictCursor
|
|
|
+from sqlalchemy import func
|
|
|
|
|
|
from pqai_agent import logging_service
|
|
|
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
|
|
|
+from pqai_agent.data_models.agent_test_task import AgentTestTask
|
|
|
from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc, \
|
|
|
get_test_task_conversations_status_desc
|
|
|
|
|
|
logger = logging_service.logger
|
|
|
+
|
|
|
+
|
|
|
class Database:
|
|
|
"""数据库操作类"""
|
|
|
|
|
@@ -110,8 +115,10 @@ class Database:
|
|
|
class TaskManager:
|
|
|
"""任务管理器"""
|
|
|
|
|
|
- def __init__(self, db_config, agent_configuration_table, test_task_table, test_task_conversations_table,
|
|
|
+ def __init__(self, session_maker, db_config, agent_configuration_table, test_task_table,
|
|
|
+ test_task_conversations_table,
|
|
|
max_workers: int = 10):
|
|
|
+ self.session_maker = session_maker
|
|
|
self.db = Database(db_config)
|
|
|
self.agent_configuration_table = agent_configuration_table
|
|
|
self.test_task_table = test_task_table
|
|
@@ -122,39 +129,72 @@ class TaskManager:
|
|
|
self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='TaskWorker')
|
|
|
self.task_futures = {} # 任务ID -> Future
|
|
|
|
|
|
- def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
|
|
|
- fetch_query = f"""
|
|
|
- select t1.id, t2.name, t1.create_user, t1.update_user, t1.status, t1.create_time, t1.update_time
|
|
|
- from {self.test_task_table} t1
|
|
|
- left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
|
|
|
- order by create_time desc
|
|
|
- limit %s, %s;
|
|
|
- """
|
|
|
- messages = self.db.fetch(fetch_query, ((page_num - 1) * page_size, page_size))
|
|
|
- total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_table}""")
|
|
|
+ # def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
|
|
|
+ # fetch_query = f"""
|
|
|
+ # select t1.id, t2.name, t1.create_user, t1.update_user, t1.status, t1.create_time, t1.update_time
|
|
|
+ # from {self.test_task_table} t1
|
|
|
+ # left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
|
|
|
+ # order by create_time desc
|
|
|
+ # limit %s, %s;
|
|
|
+ # """
|
|
|
+ # messages = self.db.fetch(fetch_query, ((page_num - 1) * page_size, page_size))
|
|
|
+ # total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_table}""")
|
|
|
+ #
|
|
|
+ # total = total_size["count"]
|
|
|
+ # total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
|
|
|
+ # total_page = 1 if total_page <= 0 else total_page
|
|
|
+ # response_data = [
|
|
|
+ # {
|
|
|
+ # "id": message["id"],
|
|
|
+ # "agentName": message["name"],
|
|
|
+ # "createUser": message["create_user"],
|
|
|
+ # "updateUser": message["update_user"],
|
|
|
+ # "statusName": get_test_task_status_desc(message["status"]),
|
|
|
+ # "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
+ # "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ # }
|
|
|
+ # for message in messages
|
|
|
+ # ]
|
|
|
+ # return {
|
|
|
+ # "currentPage": page_num,
|
|
|
+ # "pageSize": page_size,
|
|
|
+ # "totalSize": total_page,
|
|
|
+ # "total": total,
|
|
|
+ # "list": response_data,
|
|
|
+ # }
|
|
|
|
|
|
- total = total_size["count"]
|
|
|
- total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
|
|
|
- total_page = 1 if total_page <= 0 else total_page
|
|
|
- response_data = [
|
|
|
- {
|
|
|
- "id": message["id"],
|
|
|
- "agentName": message["name"],
|
|
|
- "createUser": message["create_user"],
|
|
|
- "updateUser": message["update_user"],
|
|
|
- "statusName": get_test_task_status_desc(message["status"]),
|
|
|
- "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
- "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
|
|
|
+ with self.session_maker() as session:
|
|
|
+ # 计算偏移量
|
|
|
+ offset = (page_num - 1) * page_size
|
|
|
+ # 查询分页数据
|
|
|
+ result = (session.query(AgentTestTask, AgentConfiguration)
|
|
|
+ .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
|
|
|
+ .limit(page_size).offset(offset).all())
|
|
|
+ # 查询总记录数
|
|
|
+ total = session.query(func.count(AgentTestTask.id)).scalar()
|
|
|
+
|
|
|
+ total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
|
|
|
+ total_page = 1 if total_page <= 0 else total_page
|
|
|
+ response_data = [
|
|
|
+ {
|
|
|
+ "id": agent_test_task.id,
|
|
|
+ "agentName": agent_configuration.name,
|
|
|
+ "createUser": agent_test_task.create_user,
|
|
|
+ "updateUser": agent_test_task.update_user,
|
|
|
+ "statusName": get_test_task_status_desc(agent_test_task.status),
|
|
|
+ "createTime": agent_test_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
|
|
|
+ "updateTime": agent_test_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+ }
|
|
|
+ for agent_test_task, agent_configuration in result
|
|
|
+ ]
|
|
|
+ return {
|
|
|
+ "currentPage": page_num,
|
|
|
+ "pageSize": page_size,
|
|
|
+ "totalSize": total_page,
|
|
|
+ "total": total,
|
|
|
+ "list": response_data,
|
|
|
}
|
|
|
- for message in messages
|
|
|
- ]
|
|
|
- return {
|
|
|
- "currentPage": page_num,
|
|
|
- "pageSize": page_size,
|
|
|
- "totalSize": total_page,
|
|
|
- "total": total,
|
|
|
- "list": response_data,
|
|
|
- }
|
|
|
|
|
|
def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
|
|
|
fetch_query = f"""
|
|
@@ -167,8 +207,9 @@ class TaskManager:
|
|
|
limit %s, %s;
|
|
|
"""
|
|
|
messages = self.db.fetch(fetch_query, (task_id, (page_num - 1) * page_size, page_size))
|
|
|
- total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_conversations_table} where task_id = %s""",
|
|
|
- (task_id,))
|
|
|
+ total_size = self.db.fetch_one(
|
|
|
+ f"""select count(*) as `count` from {self.test_task_conversations_table} where task_id = %s""",
|
|
|
+ (task_id,))
|
|
|
|
|
|
total = total_size["count"]
|
|
|
total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
|
|
@@ -195,7 +236,7 @@ class TaskManager:
|
|
|
"list": response_data,
|
|
|
}
|
|
|
|
|
|
- def create_task(self, agent_id: int, model_id: int ) -> Dict:
|
|
|
+ def create_task(self, agent_id: int, model_id: int) -> Dict:
|
|
|
"""创建新任务"""
|
|
|
|
|
|
conn = self.db.get_connection()
|