""" MySQL工具库测试用例 注意:运行测试前请确保: 1. 数据库连接配置正确 2. 有测试用的数据表(或者在测试中创建) 3. 有足够的数据库权限 """ import unittest import time import sys import os # 添加项目根目录到路径,支持直接运行测试 current_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.dirname(os.path.dirname(current_dir)) sys.path.insert(0, project_root) try: # 优先尝试相对导入(从包内运行) from . import mysql_db except ImportError: # 备用绝对导入(直接运行文件) from utils.mysql import mysql_db class TestMySQLUtils(unittest.TestCase): """MySQL工具库测试类""" @classmethod def setUpClass(cls): """测试类初始化""" print("开始MySQL工具库测试...") # 创建测试表(如果不存在) try: mysql_db.execute_update(""" CREATE TABLE IF NOT EXISTS test_users ( id INT AUTO_INCREMENT PRIMARY KEY, name VARCHAR(100) NOT NULL, email VARCHAR(100) UNIQUE, age INT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP ) """) print("测试表创建成功") except Exception as e: print(f"创建测试表失败: {e}") raise @classmethod def tearDownClass(cls): """测试类清理""" try: # 清理测试数据 mysql_db.execute_update("DELETE FROM test_users WHERE name LIKE 'Test%'") print("测试数据清理完成") except Exception as e: print(f"清理测试数据失败: {e}") def test_01_basic_insert(self): """测试基础插入操作""" print("\n测试基础插入操作...") test_data = { 'name': 'Test User 1', 'email': 'test1@example.com', 'age': 25 } user_id = mysql_db.insert('test_users', test_data) self.assertIsNotNone(user_id) self.assertGreater(user_id, 0) print(f"插入成功,ID: {user_id}") # 验证插入的数据 user = mysql_db.select_one('test_users', where='id = %s', where_params=(user_id,)) self.assertIsNotNone(user) self.assertEqual(user['name'], 'Test User 1') self.assertEqual(user['email'], 'test1@example.com') def test_02_basic_select(self): """测试基础查询操作""" print("\n测试基础查询操作...") # 查询所有测试用户 users = mysql_db.select('test_users', where="name LIKE %s", where_params=('Test%',)) self.assertGreaterEqual(len(users), 1) print(f"查询到 {len(users)} 个测试用户") # 查询单个用户 user = mysql_db.select_one('test_users', where="name = %s", where_params=('Test User 1',)) self.assertIsNotNone(user) self.assertEqual(user['name'], 'Test User 1') def test_03_basic_update(self): """测试基础更新操作""" print("\n测试基础更新操作...") # 更新用户年龄 affected_rows = mysql_db.update( 'test_users', {'age': 26}, 'name = %s', ('Test User 1',) ) self.assertGreater(affected_rows, 0) print(f"更新了 {affected_rows} 条记录") # 验证更新结果 user = mysql_db.select_one('test_users', where="name = %s", where_params=('Test User 1',)) self.assertEqual(user['age'], 26) def test_04_count_and_exists(self): """测试计数和存在性检查""" print("\n测试计数和存在性检查...") # 测试计数 count = mysql_db.count('test_users', where="name LIKE %s", where_params=('Test%',)) self.assertGreaterEqual(count, 1) print(f"测试用户总数: {count}") # 测试存在性检查 exists = mysql_db.exists('test_users', 'name = %s', ('Test User 1',)) self.assertTrue(exists) not_exists = mysql_db.exists('test_users', 'name = %s', ('Non Existent User',)) self.assertFalse(not_exists) def test_05_batch_insert(self): """测试批量插入""" print("\n测试批量插入...") users_data = [ {'name': 'Test User 2', 'email': 'test2@example.com', 'age': 22}, {'name': 'Test User 3', 'email': 'test3@example.com', 'age': 23}, {'name': 'Test User 4', 'email': 'test4@example.com', 'age': 24} ] affected_rows = mysql_db.insert_many('test_users', users_data) self.assertEqual(affected_rows, 3) print(f"批量插入了 {affected_rows} 条记录") def test_06_pagination(self): """测试分页查询""" print("\n测试分页查询...") result = mysql_db.paginate( 'test_users', page=1, page_size=2, where="name LIKE %s", where_params=('Test%',), order_by='id ASC' ) self.assertIn('data', result) self.assertIn('pagination', result) self.assertEqual(len(result['data']), 2) self.assertEqual(result['pagination']['current_page'], 1) self.assertEqual(result['pagination']['page_size'], 2) print(f"分页查询结果: 当前页 {result['pagination']['current_page']}, " f"总记录数 {result['pagination']['total_count']}") def test_07_sorting(self): """测试排序查询""" print("\n测试排序查询...") # 单字段排序 users = mysql_db.select_with_sort( 'test_users', where="name LIKE %s", where_params=('Test%',), sort_field='age', sort_order='DESC', limit=3 ) self.assertGreaterEqual(len(users), 1) print(f"按年龄降序查询到 {len(users)} 个用户") # 验证排序结果 if len(users) > 1: self.assertGreaterEqual(users[0]['age'], users[1]['age']) def test_08_aggregation(self): """测试聚合查询""" print("\n测试聚合查询...") agg_result = mysql_db.aggregate( 'test_users', { 'total_count': 'COUNT(*)', 'avg_age': 'AVG(age)', 'max_age': 'MAX(age)', 'min_age': 'MIN(age)' }, where="name LIKE %s", where_params=('Test%',) ) self.assertEqual(len(agg_result), 1) result = agg_result[0] self.assertGreater(result['total_count'], 0) self.assertIsNotNone(result['avg_age']) self.assertIsNotNone(result['max_age']) self.assertIsNotNone(result['min_age']) print(f"聚合查询结果: {result}") def test_09_search(self): """测试模糊搜索""" print("\n测试模糊搜索...") results = mysql_db.search( 'test_users', ['name', 'email'], 'Test', limit=10 ) self.assertGreaterEqual(len(results), 1) print(f"搜索到 {len(results)} 条记录") # 验证搜索结果 for result in results: self.assertTrue( 'Test' in result['name'] or 'Test' in (result['email'] or '') ) def test_10_transaction(self): """测试事务操作""" print("\n测试事务操作...") try: with mysql_db.transaction(): # 在事务中插入数据 user_id = mysql_db.insert('test_users', { 'name': 'Test Transaction User', 'email': 'trans@example.com', 'age': 30 }) # 更新刚插入的数据 mysql_db.update('test_users', {'age': 31}, 'id = %s', (user_id,)) print(f"事务中插入并更新用户,ID: {user_id}") # 验证事务提交后的数据 user = mysql_db.select_one('test_users', where='id = %s', where_params=(user_id,)) self.assertIsNotNone(user) self.assertEqual(user['age'], 31) except Exception as e: self.fail(f"事务测试失败: {e}") def test_11_transaction_rollback(self): """测试事务回滚""" print("\n测试事务回滚...") initial_count = mysql_db.count('test_users') print(f"事务前记录数: {initial_count}") try: with mysql_db.transaction() as conn: # 在事务中插入一个用户,传递连接参数 user_id = mysql_db.insert('test_users', { 'name': 'Test Rollback User', 'email': 'rollback@example.com', 'age': 35 }, connection=conn) print(f"事务中插入用户ID: {user_id}") # 人为抛出异常触发回滚 raise ValueError("测试回滚") except ValueError: # 这是预期的异常 print("捕获到预期的异常,事务应该已回滚") # 等待一下确保事务完全处理 import time time.sleep(0.1) # 验证回滚后数据没有增加 final_count = mysql_db.count('test_users') print(f"事务后记录数: {final_count}") if initial_count != final_count: # 如果计数不匹配,查看是否是我们插入的测试数据 rollback_user = mysql_db.select_one('test_users', where="name = %s", where_params=('Test Rollback User',)) if rollback_user: print("❌ 事务回滚失败,找到了应该被回滚的数据") # 手动清理这条数据 mysql_db.delete('test_users', 'name = %s', ('Test Rollback User',)) self.fail("事务回滚失败") else: print("✅ 虽然计数不同,但回滚用户确实不存在,可能是其他并发操作") print("事务回滚测试通过") def test_12_connection_pool(self): """测试连接池""" print("\n测试连接池...") # 获取连接池状态 status = mysql_db.pool.get_pool_status() self.assertIn('active_connections', status) self.assertIn('pool_size', status) self.assertIn('max_connections', status) print(f"连接池状态: {status}") # 测试并发获取连接 connections = [] try: for i in range(3): conn = mysql_db.pool.get_connection() connections.append(conn) # 验证连接可用性 for conn in connections: conn.ping() print("连接池并发测试通过") finally: # 归还连接 for conn in connections: mysql_db.pool.return_connection(conn) def test_13_error_handling(self): """测试错误处理""" print("\n测试错误处理...") # 测试查询不存在的表 with self.assertRaises(Exception): # 捕获任何异常,因为可能是pymysql原生异常 mysql_db.select('non_existent_table') # 测试插入空数据 with self.assertRaises(ValueError): mysql_db.insert('test_users', {}) print("错误处理测试通过") def test_14_performance(self): """性能测试""" print("\n性能测试...") # 测试批量插入性能 start_time = time.time() batch_data = [] for i in range(100): batch_data.append({ 'name': f'Perf Test User {i}', 'email': f'perf{i}@example.com', 'age': 20 + (i % 30) }) mysql_db.insert_many('test_users', batch_data) end_time = time.time() execution_time = end_time - start_time print(f"批量插入100条记录耗时: {execution_time:.4f}秒") self.assertLess(execution_time, 5.0) # 应该在5秒内完成 # 清理性能测试数据 mysql_db.delete('test_users', 'name LIKE %s', ('Perf Test User%',)) def test_15_cleanup(self): """清理测试数据""" print("\n清理额外的测试数据...") # 删除事务测试用户 mysql_db.delete('test_users', 'name = %s', ('Test Transaction User',)) print("清理完成") def run_tests(): """运行测试套件""" # 创建测试套件 test_suite = unittest.TestLoader().loadTestsFromTestCase(TestMySQLUtils) # 运行测试 runner = unittest.TextTestRunner(verbosity=2) result = runner.run(test_suite) # 输出测试结果 if result.wasSuccessful(): print(f"\n✅ 所有测试通过! 运行了 {result.testsRun} 个测试") else: print(f"\n❌ 测试失败! {len(result.failures)} 个失败, {len(result.errors)} 个错误") return result.wasSuccessful() if __name__ == '__main__': try: success = run_tests() exit(0 if success else 1) except Exception as e: print(f"测试运行出错: {e}") exit(1)