12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970 |
- #! /usr/bin/env python
- # -*- coding: utf-8 -*-
- # vim:fenc=utf-8
- #
- # Copyright © 2024 StrayWarrior <i@straywarrior.com>
- import pymysql
- from logging_service import logger
class MySQLManager:
    """
    Small helper that runs SQL through short-lived pymysql connections.

    Every method opens its own connection from the stored config and closes
    it before returning, whether the call succeeds or raises.
    """

    def __init__(self, config):
        """
        config: keyword arguments passed to pymysql.connect, dict
        """
        self.config = config

    def select(self, sql, cursor_type=None, args=None):
        """
        Run a query and return all fetched rows.

        sql: SQL to execute, string
        cursor_type: cursor class passed to conn.cursor(), or None for the
            connection's default
        args: query parameters passed to cursor.execute(), optional

        Exceptions are not handled here and propagate to the caller; the
        cursor and the connection are closed either way.
        """
        conn = pymysql.connect(**self.config)
        try:
            with conn.cursor(cursor_type) as cursor:
                cursor.execute(sql, args)
                return cursor.fetchall()
        finally:
            conn.close()

    def execute(self, sql, args=None):
        """
        Run a single statement, commit it and return the affected row count.

        sql: SQL to execute, string
        args: statement parameters passed to cursor.execute(), optional

        On any exception the transaction is rolled back and the exception is
        re-raised unchanged.
        """
        conn = pymysql.connect(**self.config)
        try:
            with conn.cursor() as cursor:
                cursor.execute(sql, args)
                affected_rows = cursor.rowcount
            conn.commit()
            return affected_rows
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def batch_insert(self, table, data, columns=None, ignore=False):
        """
        Insert many rows with one executemany() call and commit.

        table: table name, string
        data: data, list[tuple] or list[dict]
        columns: column names, list, required if data is list[tuple]
        ignore: emit INSERT IGNORE instead of INSERT when true

        Returns None when data is None or empty, otherwise the value returned
        by cursor.executemany().

        Raises TypeError when rows are not dicts and columns is missing,
        ValueError when the first row's length differs from len(columns).
        pymysql.MySQLError is logged, rolled back and re-raised.

        NOTE: table and column names are interpolated into the SQL text
        as-is (only the values are parameterized), so they must come from
        trusted code, never from user input.
        """
        if data is None or len(data) == 0:
            return None

        # Validate and build the statement before connecting, so bad
        # arguments never cost a database connection.
        first_row = data[0]
        if isinstance(first_row, dict):
            # The first row's keys define the column list for every row.
            keys = list(first_row)
            columns_str = ','.join(keys)
            placeholders_str = ','.join([f'%({key})s' for key in keys])
        else:
            if columns is None:
                raise TypeError("columns is required when data rows are not dicts")
            if len(first_row) != len(columns):
                raise ValueError("data length != column length")
            columns_str = ','.join(columns)
            placeholders_str = ','.join(['%s'] * len(first_row))
        ignore_keyword = 'IGNORE' if ignore else ''
        sql_str = f"INSERT {ignore_keyword} INTO {table} ({columns_str}) VALUES ({placeholders_str})"

        conn = pymysql.connect(**self.config)
        try:
            with conn.cursor() as cursor:
                rows = cursor.executemany(sql_str, data)
            conn.commit()
            return rows
        except pymysql.MySQLError as e:
            logger.error(f"Error in batch_insert: {e}")
            conn.rollback()
            raise
        finally:
            # Runs on both the return and the raise path; a close() placed
            # after the try block would never be reached.
            conn.close()
|