import logging
import time
import docker
import longvideoapi_config
import utils
from longvideoapi_update_k import server_health_check

_LOG_FORMAT = '%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s'
_LOG_DATE_FORMAT = '%a, %d %b %Y %H:%M:%S'

# Configure the root logger at INFO, stamping each record with time, source file and line.
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
)

def server_restart(alb_client, ecs_client, instance_id, image_name, port,
                   docker_port=2375, docker_timeout=60):
    """Restart the service container on one instance with its ALB traffic drained.

    Steps: look up the instance IP, set its ALB weight to 0 in every configured
    server group, restart the container through the remote Docker API, run the
    health check, then ramp the ALB weight back up in stages.

    Args:
        alb_client: ALB client handed to the ``utils`` weight helpers.
        ecs_client: ECS client used for the IP lookup and the health check.
        instance_id: ECS instance whose container is restarted.
        image_name: Passed to ``containers.get``, so it must be the container
            name or ID on the remote host.
        port: Backend port registered in the ALB server groups.
        docker_port: TCP port of the remote Docker daemon (default 2375).
        docker_timeout: Docker API client timeout in seconds (default 60).

    Returns:
        True if every step completed, False if any step raised. Exceptions are
        logged with their traceback rather than propagated. A failure after the
        drain step can leave the instance at weight 0.
    """
    docker_client = None
    try:
        logging.info(f"Restarting instance: {instance_id}")
        # Look up the instance's IP address.
        ip_address = utils.get_ip_address(ecs_client=ecs_client, instance_id=instance_id)
        logging.info(f"IP address: {instance_id}/{ip_address}")

        # Drain traffic: weight 0 in every configured server group.
        # NOTE(review): tuples look like (weight, wait_seconds) consumed by utils — confirm.
        utils.update_server_group_servers_attribute(alb_client=alb_client,
                                                    server_group_id_list=longvideoapi_config.server_group_id_list,
                                                    instance_id_list=[instance_id],
                                                    weight_list=[(0, 20)],
                                                    port=port)
        logging.info(f"Set weight to 0 for instance: {instance_id}")

        # Connect to the remote Docker daemon and restart the container.
        docker_client = docker.DockerClient(base_url=f'tcp://{ip_address}:{docker_port}',
                                            timeout=docker_timeout)
        container = docker_client.containers.get(image_name)
        container.restart()
        logging.info("Docker restart finished.")

        # Liveness probe, then extra settle time before traffic is restored.
        server_health_check(ecs_client, instance_id)
        time.sleep(20)
        logging.info("Health check finished.")

        # Restore traffic gradually, ramping the weight up to 100.
        add_weight_list = [(10, 5), (20, 5), (40, 5), (60, 5), (80, 5), (100, 5)]
        utils.update_server_group_servers_attribute(alb_client=alb_client,
                                                    server_group_id_list=longvideoapi_config.server_group_id_list,
                                                    instance_id_list=[instance_id],
                                                    weight_list=add_weight_list,
                                                    port=port)
        logging.info(f"Server restart finished, instance: {instance_id}/{ip_address}")
        return True
    except Exception:
        # logging.exception keeps the traceback, which logging.error(e) dropped.
        logging.exception(f"Server restart failed, instance: {instance_id}")
        return False
    finally:
        # Release the Docker client's HTTP connection pool. Previously it was leaked for every instance.
        if docker_client is not None:
            try:
                docker_client.close()
            except Exception:
                logging.warning(f"Failed to close Docker client for instance: {instance_id}")


def main():
    """Rolling-restart every backend in the first configured ALB server group.

    Builds the ECS and ALB clients from ``longvideoapi_config`` and lists the
    instances in ``server_group_id_list[0]``. It then restarts them one at a
    time with ``server_restart``. If ``server_restart`` explicitly reports
    failure (returns ``False``), the rollout stops. This keeps a failed,
    possibly drained instance from being followed by draining the next one.
    Any exception during setup or the rollout is logged with its traceback.
    """
    try:
        ecs_client = utils.connect_client(access_key_id=longvideoapi_config.ecs_client_params['access_key_id'],
                                          access_key_secret=longvideoapi_config.ecs_client_params['access_key_secret'],
                                          region_id=longvideoapi_config.ecs_client_params['region_id'])

        alb_client = utils.connect_alb_client(
            access_key_id=longvideoapi_config.alb_client_params['access_key_id'],
            access_key_secret=longvideoapi_config.alb_client_params['access_key_secret'],
            endpoint=longvideoapi_config.alb_client_params['endpoint']
        )

        # List every instance registered in the first ALB server group.
        online_instance_ids = utils.list_server_group_servers(alb_client=alb_client,
                                                              server_group_id=longvideoapi_config.server_group_id_list[0])
        online_instance_count = len(online_instance_ids)
        logging.info(f"Online instance count: {online_instance_count}.")
        logging.info(f"Online instance ids: {online_instance_ids}")

        # Restart instances one at a time.
        image_name = 'longvideoapi'
        for progress, instance_id in enumerate(online_instance_ids, start=1):
            restarted = server_restart(alb_client=alb_client, ecs_client=ecs_client, instance_id=instance_id,
                                       image_name=image_name, port=longvideoapi_config.port)
            # Only an explicit False counts as failure. A None result carries no
            # status, so the loop keeps going as it did before.
            if restarted is False:
                logging.error(f"Aborting rolling restart at instance {instance_id} "
                              f"({progress}/{online_instance_count}); remaining instances were not restarted.")
                return
            logging.info(f"Restart progress: {progress}/{online_instance_count}")
        logging.info("All servers restarted successfully!")
    except Exception:
        # logging.exception keeps the traceback, which logging.error(e) dropped.
        logging.exception("An error occurred during the main execution.")


if __name__ == "__main__":
    # Start the rolling restart only when this file runs as a script, not on import.
    main()