| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410 |
- """
- MySQL工具库测试用例
- 注意:运行测试前请确保:
- 1. 数据库连接配置正确
- 2. 有测试用的数据表(或者在测试中创建)
- 3. 有足够的数据库权限
- """
- import unittest
- import time
- import sys
- import os
- # 添加项目根目录到路径,支持直接运行测试
- current_dir = os.path.dirname(os.path.abspath(__file__))
- project_root = os.path.dirname(os.path.dirname(current_dir))
- sys.path.insert(0, project_root)
- try:
- # 优先尝试相对导入(从包内运行)
- from . import mysql_db
- except ImportError:
- # 备用绝对导入(直接运行文件)
- from utils.mysql import mysql_db
- class TestMySQLUtils(unittest.TestCase):
- """MySQL工具库测试类"""
- @classmethod
- def setUpClass(cls):
- """测试类初始化"""
- print("开始MySQL工具库测试...")
- # 创建测试表(如果不存在)
- try:
- mysql_db.execute_update("""
- CREATE TABLE IF NOT EXISTS test_users (
- id INT AUTO_INCREMENT PRIMARY KEY,
- name VARCHAR(100) NOT NULL,
- email VARCHAR(100) UNIQUE,
- age INT,
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
- )
- """)
- print("测试表创建成功")
- except Exception as e:
- print(f"创建测试表失败: {e}")
- raise
- @classmethod
- def tearDownClass(cls):
- """测试类清理"""
- try:
- # 清理测试数据
- mysql_db.execute_update("DELETE FROM test_users WHERE name LIKE 'Test%'")
- print("测试数据清理完成")
- except Exception as e:
- print(f"清理测试数据失败: {e}")
- def test_01_basic_insert(self):
- """测试基础插入操作"""
- print("\n测试基础插入操作...")
- test_data = {
- 'name': 'Test User 1',
- 'email': 'test1@example.com',
- 'age': 25
- }
- user_id = mysql_db.insert('test_users', test_data)
- self.assertIsNotNone(user_id)
- self.assertGreater(user_id, 0)
- print(f"插入成功,ID: {user_id}")
- # 验证插入的数据
- user = mysql_db.select_one('test_users', where='id = %s', where_params=(user_id,))
- self.assertIsNotNone(user)
- self.assertEqual(user['name'], 'Test User 1')
- self.assertEqual(user['email'], 'test1@example.com')
- def test_02_basic_select(self):
- """测试基础查询操作"""
- print("\n测试基础查询操作...")
- # 查询所有测试用户
- users = mysql_db.select('test_users', where="name LIKE %s", where_params=('Test%',))
- self.assertGreaterEqual(len(users), 1)
- print(f"查询到 {len(users)} 个测试用户")
- # 查询单个用户
- user = mysql_db.select_one('test_users', where="name = %s", where_params=('Test User 1',))
- self.assertIsNotNone(user)
- self.assertEqual(user['name'], 'Test User 1')
- def test_03_basic_update(self):
- """测试基础更新操作"""
- print("\n测试基础更新操作...")
- # 更新用户年龄
- affected_rows = mysql_db.update(
- 'test_users',
- {'age': 26},
- 'name = %s',
- ('Test User 1',)
- )
- self.assertGreater(affected_rows, 0)
- print(f"更新了 {affected_rows} 条记录")
- # 验证更新结果
- user = mysql_db.select_one('test_users', where="name = %s", where_params=('Test User 1',))
- self.assertEqual(user['age'], 26)
- def test_04_count_and_exists(self):
- """测试计数和存在性检查"""
- print("\n测试计数和存在性检查...")
- # 测试计数
- count = mysql_db.count('test_users', where="name LIKE %s", where_params=('Test%',))
- self.assertGreaterEqual(count, 1)
- print(f"测试用户总数: {count}")
- # 测试存在性检查
- exists = mysql_db.exists('test_users', 'name = %s', ('Test User 1',))
- self.assertTrue(exists)
- not_exists = mysql_db.exists('test_users', 'name = %s', ('Non Existent User',))
- self.assertFalse(not_exists)
- def test_05_batch_insert(self):
- """测试批量插入"""
- print("\n测试批量插入...")
- users_data = [
- {'name': 'Test User 2', 'email': 'test2@example.com', 'age': 22},
- {'name': 'Test User 3', 'email': 'test3@example.com', 'age': 23},
- {'name': 'Test User 4', 'email': 'test4@example.com', 'age': 24}
- ]
- affected_rows = mysql_db.insert_many('test_users', users_data)
- self.assertEqual(affected_rows, 3)
- print(f"批量插入了 {affected_rows} 条记录")
- def test_06_pagination(self):
- """测试分页查询"""
- print("\n测试分页查询...")
- result = mysql_db.paginate(
- 'test_users',
- page=1,
- page_size=2,
- where="name LIKE %s",
- where_params=('Test%',),
- order_by='id ASC'
- )
- self.assertIn('data', result)
- self.assertIn('pagination', result)
- self.assertEqual(len(result['data']), 2)
- self.assertEqual(result['pagination']['current_page'], 1)
- self.assertEqual(result['pagination']['page_size'], 2)
- print(f"分页查询结果: 当前页 {result['pagination']['current_page']}, "
- f"总记录数 {result['pagination']['total_count']}")
- def test_07_sorting(self):
- """测试排序查询"""
- print("\n测试排序查询...")
- # 单字段排序
- users = mysql_db.select_with_sort(
- 'test_users',
- where="name LIKE %s",
- where_params=('Test%',),
- sort_field='age',
- sort_order='DESC',
- limit=3
- )
- self.assertGreaterEqual(len(users), 1)
- print(f"按年龄降序查询到 {len(users)} 个用户")
- # 验证排序结果
- if len(users) > 1:
- self.assertGreaterEqual(users[0]['age'], users[1]['age'])
- def test_08_aggregation(self):
- """测试聚合查询"""
- print("\n测试聚合查询...")
- agg_result = mysql_db.aggregate(
- 'test_users',
- {
- 'total_count': 'COUNT(*)',
- 'avg_age': 'AVG(age)',
- 'max_age': 'MAX(age)',
- 'min_age': 'MIN(age)'
- },
- where="name LIKE %s",
- where_params=('Test%',)
- )
- self.assertEqual(len(agg_result), 1)
- result = agg_result[0]
- self.assertGreater(result['total_count'], 0)
- self.assertIsNotNone(result['avg_age'])
- self.assertIsNotNone(result['max_age'])
- self.assertIsNotNone(result['min_age'])
- print(f"聚合查询结果: {result}")
- def test_09_search(self):
- """测试模糊搜索"""
- print("\n测试模糊搜索...")
- results = mysql_db.search(
- 'test_users',
- ['name', 'email'],
- 'Test',
- limit=10
- )
- self.assertGreaterEqual(len(results), 1)
- print(f"搜索到 {len(results)} 条记录")
- # 验证搜索结果
- for result in results:
- self.assertTrue(
- 'Test' in result['name'] or 'Test' in (result['email'] or '')
- )
- def test_10_transaction(self):
- """测试事务操作"""
- print("\n测试事务操作...")
- try:
- with mysql_db.transaction():
- # 在事务中插入数据
- user_id = mysql_db.insert('test_users', {
- 'name': 'Test Transaction User',
- 'email': 'trans@example.com',
- 'age': 30
- })
- # 更新刚插入的数据
- mysql_db.update('test_users', {'age': 31}, 'id = %s', (user_id,))
- print(f"事务中插入并更新用户,ID: {user_id}")
- # 验证事务提交后的数据
- user = mysql_db.select_one('test_users', where='id = %s', where_params=(user_id,))
- self.assertIsNotNone(user)
- self.assertEqual(user['age'], 31)
- except Exception as e:
- self.fail(f"事务测试失败: {e}")
- def test_11_transaction_rollback(self):
- """测试事务回滚"""
- print("\n测试事务回滚...")
- initial_count = mysql_db.count('test_users')
- print(f"事务前记录数: {initial_count}")
- try:
- with mysql_db.transaction() as conn:
- # 在事务中插入一个用户,传递连接参数
- user_id = mysql_db.insert('test_users', {
- 'name': 'Test Rollback User',
- 'email': 'rollback@example.com',
- 'age': 35
- }, connection=conn)
- print(f"事务中插入用户ID: {user_id}")
- # 人为抛出异常触发回滚
- raise ValueError("测试回滚")
- except ValueError:
- # 这是预期的异常
- print("捕获到预期的异常,事务应该已回滚")
- # 等待一下确保事务完全处理
- import time
- time.sleep(0.1)
- # 验证回滚后数据没有增加
- final_count = mysql_db.count('test_users')
- print(f"事务后记录数: {final_count}")
- if initial_count != final_count:
- # 如果计数不匹配,查看是否是我们插入的测试数据
- rollback_user = mysql_db.select_one('test_users', where="name = %s", where_params=('Test Rollback User',))
- if rollback_user:
- print("❌ 事务回滚失败,找到了应该被回滚的数据")
- # 手动清理这条数据
- mysql_db.delete('test_users', 'name = %s', ('Test Rollback User',))
- self.fail("事务回滚失败")
- else:
- print("✅ 虽然计数不同,但回滚用户确实不存在,可能是其他并发操作")
- print("事务回滚测试通过")
- def test_12_connection_pool(self):
- """测试连接池"""
- print("\n测试连接池...")
- # 获取连接池状态
- status = mysql_db.pool.get_pool_status()
- self.assertIn('active_connections', status)
- self.assertIn('pool_size', status)
- self.assertIn('max_connections', status)
- print(f"连接池状态: {status}")
- # 测试并发获取连接
- connections = []
- try:
- for i in range(3):
- conn = mysql_db.pool.get_connection()
- connections.append(conn)
- # 验证连接可用性
- for conn in connections:
- conn.ping()
- print("连接池并发测试通过")
- finally:
- # 归还连接
- for conn in connections:
- mysql_db.pool.return_connection(conn)
- def test_13_error_handling(self):
- """测试错误处理"""
- print("\n测试错误处理...")
- # 测试查询不存在的表
- with self.assertRaises(Exception): # 捕获任何异常,因为可能是pymysql原生异常
- mysql_db.select('non_existent_table')
- # 测试插入空数据
- with self.assertRaises(ValueError):
- mysql_db.insert('test_users', {})
- print("错误处理测试通过")
- def test_14_performance(self):
- """性能测试"""
- print("\n性能测试...")
- # 测试批量插入性能
- start_time = time.time()
- batch_data = []
- for i in range(100):
- batch_data.append({
- 'name': f'Perf Test User {i}',
- 'email': f'perf{i}@example.com',
- 'age': 20 + (i % 30)
- })
- mysql_db.insert_many('test_users', batch_data)
- end_time = time.time()
- execution_time = end_time - start_time
- print(f"批量插入100条记录耗时: {execution_time:.4f}秒")
- self.assertLess(execution_time, 5.0) # 应该在5秒内完成
- # 清理性能测试数据
- mysql_db.delete('test_users', 'name LIKE %s', ('Perf Test User%',))
- def test_15_cleanup(self):
- """清理测试数据"""
- print("\n清理额外的测试数据...")
- # 删除事务测试用户
- mysql_db.delete('test_users', 'name = %s', ('Test Transaction User',))
- print("清理完成")
- def run_tests():
- """运行测试套件"""
- # 创建测试套件
- test_suite = unittest.TestLoader().loadTestsFromTestCase(TestMySQLUtils)
- # 运行测试
- runner = unittest.TextTestRunner(verbosity=2)
- result = runner.run(test_suite)
- # 输出测试结果
- if result.wasSuccessful():
- print(f"\n✅ 所有测试通过! 运行了 {result.testsRun} 个测试")
- else:
- print(f"\n❌ 测试失败! {len(result.failures)} 个失败, {len(result.errors)} 个错误")
- return result.wasSuccessful()
- if __name__ == '__main__':
- try:
- success = run_tests()
- exit(0 if success else 1)
- except Exception as e:
- print(f"测试运行出错: {e}")
- exit(1)
|