# -*- coding: utf-8 -*-
"""Promote a trained ad-rank DNN model to the online PAI-EAS service.

Usage::

    python <this script> <model_name> <dt>

Steps:
  1. Check that both the exported model directory and the trained-CIDs
     directory for ``model_name`` / ``dt`` exist in OSS.
  2. Merge the new ``model_path`` into the EAS service ``SERVICE_NAME``.
  3. Poll the service until it reports ``Running`` (bounded by a timeout).
  4. Rewrite the ``model_version.json`` pointer in OSS to the new
     ``model_name`` / ``dt_version``.

Credentials are read from the ``ALIBABA_CLOUD_ACCESS_KEY_ID`` and
``ALIBABA_CLOUD_ACCESS_KEY_SECRET`` environment variables.
"""
import json
import os
import sys
import time

import alibabacloud_oss_v2 as oss
from alibabacloud_eas20210701 import models as eas_20210701_models
from alibabacloud_eas20210701.client import Client as eas20210701Client
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tea_util import models as util_models
from alibabacloud_tea_util.client import Client as UtilClient

SERVICE_NAME = "ad_rank_dnn_v11_easyrec"
REGION = "cn-hangzhou"
# Endpoint reference: https://api.aliyun.com/product/eas
EAS_ENDPOINT = f"pai-eas.{REGION}.aliyuncs.com"

BUCKET_NAME = "art-recommend"
MODEL_EXPORT_PREFIX = "fengzhoutian/pai_model_exports"
TRAINED_CIDS_PREFIX = "fengzhoutian/pai_model_trained_cids"
MODEL_VERSION_OBJECT_KEY = f"{TRAINED_CIDS_PREFIX}/model_version.json"

# SECURITY: an AccessKey pair used to be hard-coded here in plain text. Treat
# that pair as leaked and rotate it; credentials now come from the environment.
ACCESS_KEY_ID = os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_ID", "")
ACCESS_KEY_SECRET = os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_SECRET", "")

# Polling of the EAS service after an update, in seconds.
POLL_INTERVAL_SECONDS = 15
POLL_TIMEOUT_SECONDS = 30 * 60
# NOTE(review): assumed terminal failure status reported by DescribeService;
# confirm the full list of status values against the EAS API documentation.
_EAS_FAILED_STATUSES = frozenset({"Failed"})


def _require_credentials():
    """Return ``(access_key_id, access_key_secret)``.

    Raises:
        RuntimeError: if either environment variable is unset or empty.
    """
    if not (ACCESS_KEY_ID and ACCESS_KEY_SECRET):
        raise RuntimeError(
            "缺少阿里云凭据: 请设置环境变量 ALIBABA_CLOUD_ACCESS_KEY_ID 和 "
            "ALIBABA_CLOUD_ACCESS_KEY_SECRET"
        )
    return ACCESS_KEY_ID, ACCESS_KEY_SECRET


def _report_eas_error(error):
    """Print the diagnostics carried by an Alibaba Cloud SDK error.

    SDK errors carry a ``message`` and a ``data`` dict whose ``Recommend``
    entry is a diagnosis URL. Other exceptions (network errors, KeyError, ...)
    lack these attributes, so they are read defensively rather than raising
    AttributeError and masking the original failure.
    """
    print(getattr(error, "message", error))
    data = getattr(error, "data", None)
    if isinstance(data, dict):
        print(data.get("Recommend"))  # diagnosis URL


class EASClient:
    """Static helpers around the PAI-EAS OpenAPI client for ``SERVICE_NAME``."""

    def __init__(self):
        pass

    @staticmethod
    def create_client() -> eas20210701Client:
        """Create an EAS OpenAPI client using the environment credentials.

        @return: Client
        @throws RuntimeError: if the credentials are not configured.
        """
        access_key_id, access_key_secret = _require_credentials()
        config = open_api_models.Config(
            access_key_id=access_key_id,
            access_key_secret=access_key_secret,
        )
        config.endpoint = EAS_ENDPOINT
        return eas20210701Client(config)

    @staticmethod
    def update_online_model_path(model_path_config: str):
        """Merge ``model_path_config`` (a JSON string) into the service config.

        Returns the SDK response. Any error is printed and re-raised.
        """
        client = EASClient.create_client()
        update_service_request = eas_20210701_models.UpdateServiceRequest(
            update_type='merge',
            body=model_path_config,
        )
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            return client.update_service_with_options(
                REGION, SERVICE_NAME, update_service_request, headers, runtime
            )
        except Exception as error:
            _report_eas_error(error)
            raise

    @staticmethod
    def get_online_model_detail():
        """Return the DescribeService response body for ``SERVICE_NAME`` as a dict.

        Any error is printed and re-raised.
        """
        client = EASClient.create_client()
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            res = client.describe_service_with_options(
                REGION, SERVICE_NAME, headers, runtime
            )
            return res.body.to_map()
        except Exception as error:
            _report_eas_error(error)
            raise


def _create_oss_client():
    """Create an OSS client for ``REGION`` using the environment credentials."""
    access_key_id, access_key_secret = _require_credentials()
    oss_config = oss.config.load_default()
    oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
        access_key_id=access_key_id,
        access_key_secret=access_key_secret,
    )
    oss_config.region = REGION
    return oss.Client(oss_config)


def exist_oss_directory_path(directory_path, bucket_name=BUCKET_NAME):
    """Return True if at least one object exists under ``directory_path/``.

    Args:
        directory_path: object-key prefix; a trailing ``/`` is added if missing.
        bucket_name: bucket to search (defaults to ``BUCKET_NAME``).
    """
    if not directory_path.endswith('/'):
        directory_path += '/'
    request = oss.ListObjectsRequest(bucket_name)
    request.prefix = directory_path
    # Existence check only: a single key is enough.
    request.max_keys = 1
    result = _create_oss_client().list_objects(request)
    # ``contents`` may be None when nothing matches.
    return bool(result.contents)


def update_trained_cids_pointer(model_name=None, dt_version=None):
    """Write the ``model_version.json`` pointer for ``model_name`` / ``dt_version``.

    The JSON holds ``modelName`` / ``dtVersion`` in the same ``key=value`` form
    as the trained-CIDs directory segments, plus a Unix ``timestamp``.

    Returns:
        The OSS PutObject result.

    Raises:
        Exception: if either argument is empty.
    """
    # Both values are required; a half-specified pointer is rejected.
    if not (model_name and dt_version):
        raise Exception("model_name 和 dt_version 必须同时提供")

    model_version = {
        'modelName': f"model_name={model_name}",
        'dtVersion': f"dt_version={dt_version}",
        'timestamp': int(time.time()),
    }
    body = json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8')
    return _create_oss_client().put_object(oss.PutObjectRequest(
        bucket=BUCKET_NAME,
        key=MODEL_VERSION_OBJECT_KEY,
        body=body,
    ))


def _wait_until_running(poll_interval, timeout):
    """Poll the EAS service status until it is ``Running``.

    Returns:
        True once the status is ``Running``. False if a failure status is
        reported or ``timeout`` seconds elapse.
    """
    deadline = time.monotonic() + timeout
    while True:
        # Sleep before every poll, including the first, so the update has
        # time to take effect before the status is read.
        time.sleep(poll_interval)
        status = EASClient.get_online_model_detail().get('Status')
        if status == 'Running':
            return True
        if status in _EAS_FAILED_STATUSES:
            print(f'在线服务状态为 {status},不再更新cid文件')
            return False
        if time.monotonic() >= deadline:
            print(f'等待在线服务就绪超时({timeout}秒),不再更新cid文件')
            return False


def update_online_dnn_version(model_name: str, dt: str,
                              poll_interval: float = POLL_INTERVAL_SECONDS,
                              timeout: float = POLL_TIMEOUT_SECONDS) -> bool:
    """Deploy model ``model_name`` at ``dt`` and, on success, update the CID pointer.

    Args:
        model_name: model name used in the OSS export / CID directory names.
        dt: version (date) directory of the export.
        poll_interval: seconds between EAS status polls.
        timeout: maximum seconds to wait for the service to become Running.

    Returns:
        True if both the online model and the CID pointer were updated,
        otherwise False (the reason is printed).
    """
    if not (model_name and dt):
        print('请输入 model_name 和 dt 参数')
        return False

    model_path = f'{MODEL_EXPORT_PREFIX}/ad_rank_dnn_{model_name}/{dt}'
    cids_path = f'{TRAINED_CIDS_PREFIX}/model_name={model_name}/dt_version={dt}'
    model_exist = exist_oss_directory_path(model_path)
    cids_exist = exist_oss_directory_path(cids_path)

    if not (model_exist and cids_exist):
        # Report which of the two directories is missing.
        missing_files = []
        if not model_exist:
            missing_files.append("模型文件")
        if not cids_exist:
            missing_files.append("CID文件")
        print(f"以下文件不存在: {', '.join(missing_files)}")
        return False

    model_path_config = json.dumps({"model_path": f'oss://{BUCKET_NAME}/{model_path}'})
    print(model_path_config)
    online_model_update_res = EASClient.update_online_model_path(model_path_config)
    if online_model_update_res.status_code != 200:
        print('在线模型更新失败,不再更新cid文件')
        return False

    # Only move the CID pointer once the new model is actually serving.
    if not _wait_until_running(poll_interval, timeout):
        return False

    cid_update_res = update_trained_cids_pointer(model_name, dt)
    if cid_update_res.status_code != 200:
        print("cid文件更新失败,请手动配置更新")
        return False
    print('模型更新成功,cid配置更新成功')
    return True


if __name__ == '__main__':
    if len(sys.argv) < 3:
        print('请输入 model_name 和 dt 参数')
        sys.exit(1)
    # Exit non-zero on any failure so a calling scheduler can detect it.
    sys.exit(0 if update_online_dnn_version(sys.argv[1], sys.argv[2]) else 1)