Selaa lähdekoodia

修改数据库操作 改成ORM

xueyiming 1 viikko sitten
vanhempi
commit
1d423c0524

+ 22 - 0
pqai_agent/data_models/agent_test_task.py

@@ -0,0 +1,22 @@
+from enum import Enum
+
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+
+
+class AgentTestTask(Base):
+    __tablename__ = "agent_test_task"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
+    agent_id = Column(BigInteger, nullable=False, comment="agent主键")
+    model_id = Column(BigInteger, nullable=False, comment="model主键")
+    create_user = Column(String(32), nullable=True, comment="创建用户")
+    update_user = Column(String(32), nullable=True, comment="更新用户")
+    dataset_ids = Column(Text, nullable=False, comment="数据集ids")
+    status = Column(Integer, nullable=False, comment="状态(0:未开始, 1:进行中, 2:已完成, 3:已取消)")
+    create_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", comment="创建时间")
+    update_time = Column(TIMESTAMP, nullable=True, server_default="CURRENT_TIMESTAMP", onupdate="CURRENT_TIMESTAMP", comment="更新时间")

+ 9 - 6
pqai_agent_server/api_server.py

@@ -641,12 +641,7 @@ if __name__ == '__main__':
     user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
     app.user_manager = user_manager
 
-    task_manager = TaskManager(
-        db_config=user_db_config['mysql'],
-        agent_configuration_table=agent_configuration_db_config['table'],
-        test_task_table=test_task_db_config['table'],
-        test_task_conversations_table=test_task_conversations_db_config['table'])
-    app.task_manager = task_manager
+
 
     # init session manager
     session_manager = MySQLSessionManager(
@@ -660,6 +655,14 @@ if __name__ == '__main__':
     agent_db_engine = create_sql_engine(config['storage']['agent_state']['mysql'])
     app.session_maker = sessionmaker(bind=agent_db_engine)
 
+    task_manager = TaskManager(
+        session_maker = sessionmaker(bind=agent_db_engine),
+        db_config=user_db_config['mysql'],
+        agent_configuration_table=agent_configuration_db_config['table'],
+        test_task_table=test_task_db_config['table'],
+        test_task_conversations_table=test_task_conversations_db_config['table'])
+    app.task_manager = task_manager
+
     wecom_db_config = config['storage']['user_relation']
     user_relation_manager = MySQLUserRelationManager(
         user_db_config['mysql'], wecom_db_config['mysql'],

+ 76 - 35
pqai_agent_server/task_server.py

@@ -8,12 +8,17 @@ from typing import List, Dict, Optional
 import pymysql
 from pymysql import Connection
 from pymysql.cursors import DictCursor
+from sqlalchemy import func
 
 from pqai_agent import logging_service
+from pqai_agent.data_models.agent_configuration import AgentConfiguration
+from pqai_agent.data_models.agent_test_task import AgentTestTask
 from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc, \
     get_test_task_conversations_status_desc
 
 logger = logging_service.logger
+
+
 class Database:
     """数据库操作类"""
 
@@ -110,8 +115,10 @@ class Database:
 class TaskManager:
     """任务管理器"""
 
-    def __init__(self, db_config, agent_configuration_table, test_task_table, test_task_conversations_table,
+    def __init__(self, session_maker, db_config, agent_configuration_table, test_task_table,
+                 test_task_conversations_table,
                  max_workers: int = 10):
+        self.session_maker = session_maker
         self.db = Database(db_config)
         self.agent_configuration_table = agent_configuration_table
         self.test_task_table = test_task_table
@@ -122,39 +129,72 @@ class TaskManager:
         self.executor = ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix='TaskWorker')
         self.task_futures = {}  # 任务ID -> Future
 
-    def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
-        fetch_query = f"""
-            select t1.id, t2.name, t1.create_user, t1.update_user, t1.status, t1.create_time, t1.update_time
-            from {self.test_task_table} t1
-            left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
-            order by create_time desc
-            limit %s, %s;
-        """
-        messages = self.db.fetch(fetch_query, ((page_num - 1) * page_size, page_size))
-        total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_table}""")
+    # def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
+    #     fetch_query = f"""
+    #         select t1.id, t2.name, t1.create_user, t1.update_user, t1.status, t1.create_time, t1.update_time
+    #         from {self.test_task_table} t1
+    #         left join {self.agent_configuration_table} t2 on t1.agent_id = t2.id
+    #         order by create_time desc
+    #         limit %s, %s;
+    #     """
+    #     messages = self.db.fetch(fetch_query, ((page_num - 1) * page_size, page_size))
+    #     total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_table}""")
+    #
+    #     total = total_size["count"]
+    #     total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
+    #     total_page = 1 if total_page <= 0 else total_page
+    #     response_data = [
+    #         {
+    #             "id": message["id"],
+    #             "agentName": message["name"],
+    #             "createUser": message["create_user"],
+    #             "updateUser": message["update_user"],
+    #             "statusName": get_test_task_status_desc(message["status"]),
+    #             "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
+    #             "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
+    #         }
+    #         for message in messages
+    #     ]
+    #     return {
+    #         "currentPage": page_num,
+    #         "pageSize": page_size,
+    #         "totalSize": total_page,
+    #         "total": total,
+    #         "list": response_data,
+    #     }
 
-        total = total_size["count"]
-        total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
-        total_page = 1 if total_page <= 0 else total_page
-        response_data = [
-            {
-                "id": message["id"],
-                "agentName": message["name"],
-                "createUser": message["create_user"],
-                "updateUser": message["update_user"],
-                "statusName": get_test_task_status_desc(message["status"]),
-                "createTime": message["create_time"].strftime("%Y-%m-%d %H:%M:%S"),
-                "updateTime": message["update_time"].strftime("%Y-%m-%d %H:%M:%S")
+    def get_test_task_list(self, page_num: int, page_size: int) -> Dict:
+        with self.session_maker() as session:
+            # 计算偏移量
+            offset = (page_num - 1) * page_size
+            # 查询分页数据
+            result = (session.query(AgentTestTask, AgentConfiguration)
+                      .outerjoin(AgentConfiguration, AgentTestTask.agent_id == AgentConfiguration.id)
+                      .limit(page_size).offset(offset).all())
+            # 查询总记录数
+            total = session.query(func.count(AgentTestTask.id)).scalar()
+
+            total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
+            total_page = 1 if total_page <= 0 else total_page
+            response_data = [
+                {
+                    "id": agent_test_task.id,
+                    "agentName": agent_configuration.name,
+                    "createUser": agent_test_task.create_user,
+                    "updateUser": agent_test_task.update_user,
+                    "statusName": get_test_task_status_desc(agent_test_task.status),
+                    "createTime": agent_test_task.create_time.strftime("%Y-%m-%d %H:%M:%S"),
+                    "updateTime": agent_test_task.update_time.strftime("%Y-%m-%d %H:%M:%S")
+                }
+                for agent_test_task, agent_configuration in result
+            ]
+            return {
+                "currentPage": page_num,
+                "pageSize": page_size,
+                "totalSize": total_page,
+                "total": total,
+                "list": response_data,
             }
-            for message in messages
-        ]
-        return {
-            "currentPage": page_num,
-            "pageSize": page_size,
-            "totalSize": total_page,
-            "total": total,
-            "list": response_data,
-        }
 
     def get_test_task_conversations(self, task_id: int, page_num: int, page_size: int) -> Dict:
         fetch_query = f"""
@@ -167,8 +207,9 @@ class TaskManager:
             limit %s, %s;
         """
         messages = self.db.fetch(fetch_query, (task_id, (page_num - 1) * page_size, page_size))
-        total_size = self.db.fetch_one(f"""select count(*) as `count` from {self.test_task_conversations_table} where task_id = %s""",
-                                       (task_id,))
+        total_size = self.db.fetch_one(
+            f"""select count(*) as `count` from {self.test_task_conversations_table} where task_id = %s""",
+            (task_id,))
 
         total = total_size["count"]
         total_page = total // page_size + 1 if total % page_size > 0 else total // page_size
@@ -195,7 +236,7 @@ class TaskManager:
             "list": response_data,
         }
 
-    def create_task(self, agent_id: int, model_id: int ) -> Dict:
+    def create_task(self, agent_id: int, model_id: int) -> Dict:
         """创建新任务"""
 
         conn = self.db.get_connection()