import logging
import sys
import time
import docker
import gateway_config
import alb_utils
import requests

# Root logger configuration: INFO and above, with file/line context on each record.
logging.basicConfig(
    format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
    datefmt='%a, %d %b %Y %H:%M:%S',
    level=logging.INFO,
)

# Instances that passed the health check, as (instance_id, ip_address) tuples.
health_instances = []

def server_health_check(ecs_client, instance_id, port=9000, interval=10,
                        request_timeout=5, max_attempts=None):
    """
    Poll an instance's health-check endpoint until it answers HTTP 200.

    The instance IP is resolved once via ``alb_utils.get_ip_address``, then
    ``http://<ip>:<port>/healthcheck`` is requested every ``interval`` seconds.
    On success ``(instance_id, ip_address)`` is appended to the module-level
    ``health_instances`` list.

    :param ecs_client: ECS client connection, passed to alb_utils.get_ip_address
    :param instance_id: ECS instance id to probe
    :param port: health-check port (default 9000)
    :param interval: seconds to sleep between failed attempts (default 10)
    :param request_timeout: per-request timeout in seconds, so a hung
        connection cannot stall the loop forever
    :param max_attempts: give up after this many attempts; None (default)
        polls indefinitely
    :return: True once the instance is healthy, False if max_attempts was
        exhausted without a 200 response
    """
    ip_address = alb_utils.get_ip_address(ecs_client=ecs_client, instance_id=instance_id)
    # The URL does not change between attempts, so build it once.
    health_check_url = f"http://{ip_address}:{port}/healthcheck"

    attempt = 0
    while max_attempts is None or attempt < max_attempts:
        attempt += 1
        try:
            http_code = requests.get(health_check_url, timeout=request_timeout).status_code
        except requests.RequestException:
            # Connection refused / timed out: the service is not up yet.
            # Only request errors are caught so KeyboardInterrupt and real
            # programming errors still propagate.
            logging.info(f"images is downloading ip:{ip_address}")
            http_code = 0

        if http_code == 200:
            health_instances.append((instance_id, ip_address))
            logging.info(f"health check success, instance: {instance_id}/{ip_address}")
            return True

        # Skip the pointless sleep after the final allowed attempt.
        if max_attempts is None or attempt < max_attempts:
            time.sleep(interval)

    logging.error(f"health check gave up after {attempt} attempts, instance: {instance_id}/{ip_address}")
    return False

def server_restart(alb_client, ecs_client, instance_id, image_name, port):
    """
    Gracefully restart a single instance behind the ALB.

    Steps: drain traffic (weight 0) -> restart the Docker container over the
    remote Docker API -> wait for the health check -> ramp the ALB weight back
    up to 100.

    :param alb_client: ALB client used to update server-group weights
    :param ecs_client: ECS client used to resolve the instance IP
    :param instance_id: ECS instance id to restart
    :param image_name: passed to ``containers.get``, which looks containers up
        by name or id -- so the container is assumed to be named after the image
    :param port: backend port registered in the ALB server groups
    :return: True on success, False if any step raised. On failure the
        instance may be left drained (weight 0) and needs manual attention.
    """
    try:
        logging.info(f"Restarting instance: {instance_id}")
        # Resolve the instance IP
        ip_address = alb_utils.get_ip_address(ecs_client=ecs_client, instance_id=instance_id)
        logging.info(f"IP address: {instance_id}/{ip_address}")

        # Drain traffic: set weight to 0.
        # NOTE(review): (0, 20) looks like (weight, wait-seconds) -- confirm in alb_utils.
        alb_utils.update_server_group_servers_attribute(alb_client=alb_client,
                                                    server_group_id_list=gateway_config.server_group_id_list,
                                                    instance_id_list=[instance_id],
                                                    weight_list=[(0, 20)],
                                                    port=port)
        logging.info(f"Set weight to 0 for instance: {instance_id}")

        # Connect to the remote Docker daemon and restart the container.
        # The client is closed right away so its connection pool is not leaked.
        docker_client = docker.DockerClient(base_url=f'tcp://{ip_address}:2375', timeout=60)
        try:
            container = docker_client.containers.get(image_name)
            container.restart()
        finally:
            docker_client.close()
        logging.info("Docker restart finished.")

        # Wait until the service answers its health check, then give it
        # extra time to settle before sending traffic.
        server_health_check(ecs_client, instance_id)
        time.sleep(20)
        logging.info("Health check finished.")

        # Ramp the weight back up gradually.
        add_weight_list = [(10, 5), (20, 5), (40, 5), (60, 5), (80, 5), (100, 5)]
        alb_utils.update_server_group_servers_attribute(alb_client=alb_client,
                                                    server_group_id_list=gateway_config.server_group_id_list,
                                                    instance_id_list=[instance_id],
                                                    weight_list=add_weight_list,
                                                    port=port)
        logging.info(f"Server restart finished, instance: {instance_id}/{ip_address}")
        return True
    except Exception:
        # logging.exception keeps the traceback that logging.error(e) discarded.
        logging.exception(f"Server restart failed, instance: {instance_id}")
        return False


def main(stop_on_failure=True):
    """
    Roll-restart every instance registered in the first ALB server group.

    Instances are restarted one at a time with server_restart, so only one
    instance is drained at any moment.

    :param stop_on_failure: abort the rollout after the first failed restart.
        A failed restart can leave that instance drained, so continuing on a
        systematic error could drain the whole fleet.
    :return: True if every instance restarted successfully, False otherwise
    """
    try:
        ecs_client = alb_utils.connect_client(access_key_id=gateway_config.ecs_client_params['access_key_id'],
                                          access_key_secret=gateway_config.ecs_client_params['access_key_secret'],
                                          region_id=gateway_config.ecs_client_params['region_id'])

        alb_client = alb_utils.connect_alb_client(
            access_key_id=gateway_config.alb_client_params['access_key_id'],
            access_key_secret=gateway_config.alb_client_params['access_key_secret'],
            endpoint=gateway_config.alb_client_params['endpoint']
        )

        # List all instances behind the ALB.
        # NOTE(review): only the first server group is listed, while weights are
        # updated across every group -- assumes all groups hold the same instances.
        online_instance_ids = alb_utils.list_server_group_servers(alb_client=alb_client,
                                                              server_group_id=gateway_config.server_group_id_list[0])
        online_instance_count = len(online_instance_ids)
        logging.info(f"Online instance count: {online_instance_count}.")
        logging.info(f"Online instance ids: {online_instance_ids}")
        if not online_instance_ids:
            logging.warning("No online instances found; nothing to restart.")
            return True

        # Restart instances one by one.
        image_name = 'piaoquan-gateway'
        failed_instance_ids = []
        for i, instance_id in enumerate(online_instance_ids, start=1):
            result = server_restart(alb_client=alb_client, ecs_client=ecs_client, instance_id=instance_id,
                                    image_name=image_name, port=gateway_config.port)
            logging.info(f"Restart progress: {i}/{online_instance_count}")
            # Only an explicit False counts as failure.
            if result is False:
                failed_instance_ids.append(instance_id)
                if stop_on_failure:
                    logging.error(f"Aborting rollout after failed restart of instance: {instance_id}")
                    break

        if failed_instance_ids:
            logging.error(f"Restart failed for instances: {failed_instance_ids}")
            return False
        logging.info("All servers restarted successfully!")
        return True
    except Exception:
        logging.exception("An error occurred during the main execution.")
        return False


if __name__ == '__main__':
    # Exit non-zero when main() reports failure so cron/CI can detect it.
    sys.exit(1 if main() is False else 0)