Browse Source

修改数据库操作 改成ORM

xueyiming 1 tuần trước cách đây
mục cha
commit
43322f9cf7

+ 0 - 6
pqai_agent/configs/dev.yaml

@@ -36,12 +36,6 @@ storage:
     table: qywx_chat_history
   push_record:
     table: agent_push_record_dev
-  agent_configuration:
-    table: agent_configuration
-  test_task:
-    table: agent_test_task
-  test_task_conversations:
-    table: agent_test_task_conversations
 
 agent_behavior:
   message_aggregation_sec: 3

+ 0 - 6
pqai_agent/configs/prod.yaml

@@ -36,12 +36,6 @@ storage:
     table: qywx_chat_history
   push_record:
     table: agent_push_record_dev
-  agent_configuration:
-    table: agent_configuration
-  test_task:
-    table: agent_test_task
-  test_task_conversations:
-    table: agent_test_task_conversations
 
 chat_api:
   coze:

+ 10 - 11
pqai_agent_server/api_server.py

@@ -34,6 +34,7 @@ app = Flask('agent_api_server')
 logger = logging_service.logger
 const = AgentApiConst()
 
+
 @app.route('/api/listStaffs', methods=['GET'])
 def list_staffs():
     staff_data = app.user_relation_manager.list_staffs()
@@ -180,6 +181,7 @@ def run_prompt():
         logger.error(e)
         return wrap_response(500, msg='Error: {}'.format(e))
 
+
 @app.route('/api/formatForPrompt', methods=['POST'])
 def format_data_for_prompt():
     try:
@@ -314,6 +316,7 @@ def quit_human_interventions_status():
 
     return wrap_response(200, data=response)
 
+
 ## Agent管理接口
 @app.route("/api/getNativeAgentList", methods=["GET"])
 def get_native_agent_list():
@@ -350,6 +353,7 @@ def get_native_agent_list():
     ]
     return wrap_response(200, data=ret_data)
 
+
 @app.route("/api/getNativeAgentConfiguration", methods=["GET"])
 def get_native_agent_configuration():
     """
@@ -381,6 +385,7 @@ def get_native_agent_configuration():
         }
         return wrap_response(200, data=data)
 
+
 @app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
 def save_native_agent_configuration():
     """
@@ -434,6 +439,7 @@ def save_native_agent_configuration():
         session.commit()
         return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
 
+
 @app.route("/api/getModuleList", methods=["GET"])
 def get_module_list():
     """
@@ -458,6 +464,7 @@ def get_module_list():
     ]
     return wrap_response(200, data=ret_data)
 
+
 @app.route("/api/getModuleConfiguration", methods=["GET"])
 def get_module_configuration():
     """
@@ -484,6 +491,7 @@ def get_module_configuration():
         }
         return wrap_response(200, data=data)
 
+
 @app.route("/api/saveModuleConfiguration", methods=["POST"])
 def save_module_configuration():
     """
@@ -539,6 +547,7 @@ def get_test_task_list():
     response = app.task_manager.get_test_task_list(page_num, page_size)
     return wrap_response(200, data=response)
 
+
 @app.route("/api/getTestTaskConversations", methods=["GET"])
 def get_test_task_conversations():
     """
@@ -633,16 +642,11 @@ if __name__ == '__main__':
     staff_db_config = config['storage']['staff']
     agent_state_db_config = config['storage']['agent_state']
     chat_history_db_config = config['storage']['chat_history']
-    agent_configuration_db_config = config['storage']['agent_configuration']
-    test_task_db_config = config['storage']['test_task']
-    test_task_conversations_db_config = config['storage']['test_task_conversations']
 
     # init user manager
     user_manager = MySQLUserManager(user_db_config['mysql'], user_db_config['table'], staff_db_config['table'])
     app.user_manager = user_manager
 
-
-
     # init session manager
     session_manager = MySQLSessionManager(
         db_config=user_db_config['mysql'],
@@ -655,12 +659,7 @@ if __name__ == '__main__':
     agent_db_engine = create_sql_engine(config['storage']['agent_state']['mysql'])
     app.session_maker = sessionmaker(bind=agent_db_engine)
 
-    task_manager = TaskManager(
-        session_maker = sessionmaker(bind=agent_db_engine),
-        db_config=user_db_config['mysql'],
-        agent_configuration_table=agent_configuration_db_config['table'],
-        test_task_table=test_task_db_config['table'],
-        test_task_conversations_table=test_task_conversations_db_config['table'])
+    task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine))
     app.task_manager = task_manager
 
     wecom_db_config = config['storage']['user_relation']

+ 3 - 101
pqai_agent_server/task_server.py

@@ -20,110 +20,11 @@ from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, Tes
 logger = logging_service.logger
 
 
-class Database:
-    """数据库操作类"""
-
-    def __init__(self, db_config):
-        self.db_config = db_config
-        self.connection_pool = Queue(maxsize=10)
-        self._initialize_pool()
-
-    def _initialize_pool(self):
-        """初始化数据库连接池"""
-        for _ in range(5):
-            conn = pymysql.connect(**self.db_config)
-            self.connection_pool.put(conn)
-        logger.info("Database connection pool initialized with 5 connections")
-
-    def get_connection(self) -> Connection:
-        """从连接池获取数据库连接"""
-        return self.connection_pool.get()
-
-    def release_connection(self, conn: Connection):
-        """释放数据库连接回连接池"""
-        self.connection_pool.put(conn)
-
-    def execute(self, query: str, args: tuple = (), many: bool = False) -> int:
-        """执行SQL语句并返回影响的行数"""
-        conn = self.get_connection()
-        try:
-            with conn.cursor() as cursor:
-                if many:
-                    cursor.executemany(query, args)
-                else:
-                    cursor.execute(query, args)
-                conn.commit()
-                return cursor.rowcount
-        except Exception as e:
-            logger.error(f"Database error: {str(e)}")
-            conn.rollback()
-            raise
-        finally:
-            self.release_connection(conn)
-
-    def insert(self, insert: str, args: tuple = (), many: bool = False) -> int:
-        """执行插入SQL语句并主键"""
-        conn = self.get_connection()
-        try:
-            with conn.cursor() as cursor:
-                if many:
-                    cursor.executemany(insert, args)
-                else:
-                    cursor.execute(insert, args)
-                conn.commit()
-                return cursor.lastrowid
-        except Exception as e:
-            logger.error(f"Database error: {str(e)}")
-            conn.rollback()
-            raise
-        finally:
-            self.release_connection(conn)
-
-    def fetch(self, query: str, args: tuple = ()) -> List[Dict]:
-        """执行SQL查询并返回结果列表"""
-        conn = self.get_connection()
-        try:
-            with conn.cursor(DictCursor) as cursor:
-                cursor.execute(query, args)
-                return cursor.fetchall()
-        except Exception as e:
-            logger.error(f"Database error: {str(e)}")
-            raise
-        finally:
-            self.release_connection(conn)
-
-    def fetch_one(self, query: str, args: tuple = ()) -> Optional[Dict]:
-        """执行SQL查询并返回单行结果"""
-        conn = self.get_connection()
-        try:
-            with conn.cursor(DictCursor) as cursor:
-                cursor.execute(query, args)
-                return cursor.fetchone()
-        except Exception as e:
-            logger.error(f"Database error: {str(e)}")
-            raise
-        finally:
-            self.release_connection(conn)
-
-    def close_all(self):
-        """关闭所有数据库连接"""
-        while not self.connection_pool.empty():
-            conn = self.connection_pool.get()
-            conn.close()
-        logger.info("All database connections closed")
-
-
 class TaskManager:
     """任务管理器"""
 
-    def __init__(self, session_maker, db_config, agent_configuration_table, test_task_table,
-                 test_task_conversations_table,
-                 max_workers: int = 10):
+    def __init__(self, session_maker, max_workers: int = 10):
         self.session_maker = session_maker
-        self.db = Database(db_config)
-        self.agent_configuration_table = agent_configuration_table
-        self.test_task_table = test_task_table
-        self.test_task_conversations_table = test_task_conversations_table
         self.task_events = {}  # 任务ID -> Event (用于取消任务)
         self.task_locks = {}  # 任务ID -> Lock (用于任务状态同步)
         self.running_tasks = set()
@@ -233,7 +134,8 @@ class TaskManager:
     def get_pending_task_conversations(self, task_id: int):
         """获取待处理的子任务"""
         with self.session_maker() as session:
-            return session.query(AgentTestTaskConversations).filter(AgentTestTaskConversations.task_id == task_id).filter(
+            return session.query(AgentTestTaskConversations).filter(
+                AgentTestTaskConversations.task_id == task_id).filter(
                 AgentTestTaskConversations.status == TestTaskConversationsStatus.PENDING.value).all()
 
     def update_task_status(self, task_id: int, status: int):