#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
#
# Copyright © 2024 StrayWarrior <i@straywarrior.com>

import pymysql
from logging_service import logger

class MySQLManager:
    """Small helper around pymysql that opens a fresh connection per call.

    Every public method connects with ``pymysql.connect(**config)``, runs its
    statement(s), and closes the connection before returning or raising, so
    no connection is held between calls.
    """

    def __init__(self, config):
        """
        config: keyword arguments forwarded verbatim to pymysql.connect, dict
        """
        self.config = config

    def _connect(self):
        # Open a new connection from the stored config (one per operation).
        return pymysql.connect(**self.config)

    def select(self, sql, cursor_type=None, args=None):
        """
        Run a query and return every row it produces.

        sql: SQL to execute, string
        cursor_type: cursor class passed to conn.cursor(); None uses the
            connection's default cursor
        args: optional query parameters passed to cursor.execute
        returns: the result of cursor.fetchall()

        Exceptions are not handled here and propagate to the caller, but the
        cursor and connection are always closed.
        """
        conn = self._connect()
        try:
            with conn.cursor(cursor_type) as cursor:
                cursor.execute(sql, args)
                return cursor.fetchall()
        finally:
            # Close even when execute/fetchall raises; otherwise the
            # connection would leak on every failed query.
            conn.close()

    def execute(self, sql, args=None):
        """
        Run a single statement in its own transaction.

        sql: SQL to execute, string
        args: optional query parameters passed to cursor.execute
        returns: cursor.rowcount after the statement (affected rows)

        Commits on success. On any exception the transaction is rolled back
        and the exception is re-raised. The connection is always closed.
        """
        conn = self._connect()
        try:
            with conn.cursor() as cursor:
                cursor.execute(sql, args)
                affected_rows = cursor.rowcount
                conn.commit()
            return affected_rows
        except Exception:
            conn.rollback()
            # Bare raise keeps the original traceback untouched.
            raise
        finally:
            conn.close()

    @staticmethod
    def _build_insert_sql(table, data, columns, ignore):
        # Build the INSERT statement from the shape of data[0]; raises on bad
        # input before any database connection is opened.
        first = data[0]
        if isinstance(first, dict):
            keys = list(first.keys())
            columns_str = ','.join(keys)
            placeholders_str = ','.join(f'%({key})s' for key in keys)
        else:
            if columns is None:
                raise TypeError("columns is required when data is list[tuple]")
            if len(first) != len(columns):
                raise Exception("data length != column length")
            columns_str = ','.join(columns)
            placeholders_str = ','.join(['%s'] * len(first))
        verb = 'INSERT IGNORE INTO' if ignore else 'INSERT INTO'
        return f"{verb} {table} ({columns_str}) VALUES ({placeholders_str})"

    def batch_insert(self, table, data, columns=None, ignore=False):
        """
        Insert many rows with one executemany call in a single transaction.

        table: table name, string
        data: data, list[tuple] or list[dict]; the SQL is built from data[0],
            so every row must have the same shape/keys as the first one
        columns: column names, list, required if data is list[tuple]
        ignore: if True, emit INSERT IGNORE instead of INSERT
        returns: None when data is None or empty, otherwise the value returned
            by cursor.executemany
        raises: TypeError if columns is missing for tuple rows; Exception if a
            tuple row's length differs from len(columns); pymysql.MySQLError
            from the database (logged, rolled back, then re-raised)

        NOTE: table and column names are interpolated into the SQL text
        (identifiers cannot be bound as parameters). Only the row values are
        parameterized, so identifiers must come from trusted code, never from
        user input.
        """
        if not data:
            return None
        # Validate and build the statement first so bad input never opens a
        # connection.
        sql_str = self._build_insert_sql(table, data, columns, ignore)
        conn = self._connect()
        try:
            with conn.cursor() as cursor:
                rows = cursor.executemany(sql_str, data)
                conn.commit()
            return rows
        except pymysql.MySQLError as e:
            logger.error(f"Error in batch_insert: {e}")
            conn.rollback()
            raise
        finally:
            # Previously conn.close() sat after the try and was unreachable,
            # so every call leaked its connection.
            conn.close()