123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- # -*- coding: utf-8 -*-
- # This file is auto-generated, don't edit it. Thanks.
- import json
- import sys
- import time
- import alibabacloud_oss_v2 as oss
- from alibabacloud_eas20210701 import models as eas_20210701_models
- from alibabacloud_eas20210701.client import Client as eas20210701Client
- from alibabacloud_tea_openapi import models as open_api_models
- from alibabacloud_tea_util import models as util_models
- from alibabacloud_tea_util.client import Client as UtilClient
import os

# Name of the PAI-EAS online service this script updates.
SERVICE_NAME = "ad_rank_dnn_v11_easyrec"

# SECURITY: credentials must never be committed to source control. The AccessKey
# pair that used to be hard-coded here must be treated as leaked and rotated.
# Credentials are now read from the standard Alibaba Cloud environment variables;
# when unset they are empty strings, and API calls will fail authentication.
ACCESS_KEY_ID = os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_ID", "")
ACCESS_KEY_SECRET = os.environ.get("ALIBABA_CLOUD_ACCESS_KEY_SECRET", "")
class EASClient:
    """Static helpers around the PAI-EAS (2021-07-01) OpenAPI client for ``SERVICE_NAME``.

    All methods are static; the class only groups the EAS calls this script uses.
    """

    # Region used both for the API endpoint and as the service's cluster id.
    REGION = 'cn-hangzhou'

    @staticmethod
    def create_client() -> "eas20210701Client":
        """Build an EAS client authenticated with the module's AccessKey pair.

        @return: an ``eas20210701Client`` bound to ``pai-eas.<REGION>.aliyuncs.com``
        @throws Exception: whatever the SDK raises while constructing the client
        """
        config = open_api_models.Config(
            access_key_id=ACCESS_KEY_ID,
            access_key_secret=ACCESS_KEY_SECRET,
        )
        # Endpoint reference: https://api.aliyun.com/product/eas
        config.endpoint = f'pai-eas.{EASClient.REGION}.aliyuncs.com'
        return eas20210701Client(config)

    @staticmethod
    def _report_error(error: Exception) -> None:
        """Print an API error's message and diagnosis URL without masking the error.

        SDK errors are expected to carry ``message`` and a ``data`` dict, but other
        exceptions (network, programming errors) may have neither. Both are read
        defensively so the handler itself never raises AttributeError and hides
        the original failure.
        """
        print(getattr(error, 'message', error))
        data = getattr(error, 'data', None)
        if isinstance(data, dict):
            # Diagnosis URL returned alongside the error.
            print(data.get("Recommend"))

    @staticmethod
    def update_online_model_path(model_path_config: str):
        """Merge ``model_path_config`` into the EAS service configuration.

        :param model_path_config: JSON string merged into the service config
            (``update_type='merge'``).
        :return: the SDK ``UpdateService`` response (callers read ``status_code``).
        :raises Exception: the SDK error is re-raised after being printed.
        """
        client = EASClient.create_client()
        update_service_request = eas_20210701_models.UpdateServiceRequest(
            update_type='merge',
            body=model_path_config,
        )
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            return client.update_service_with_options(
                EASClient.REGION, SERVICE_NAME, update_service_request, headers, runtime
            )
        except Exception as error:
            EASClient._report_error(error)
            raise

    @staticmethod
    def get_online_model_detail():
        """Describe the EAS service and return the response body as a dict.

        :return: ``res.body.to_map()`` of ``DescribeService``; callers read ``'Status'``.
        :raises Exception: the SDK error is re-raised after being printed.
        """
        client = EASClient.create_client()
        runtime = util_models.RuntimeOptions()
        headers = {}
        try:
            res = client.describe_service_with_options(
                EASClient.REGION, SERVICE_NAME, headers, runtime
            )
            return res.body.to_map()
        except Exception as error:
            EASClient._report_error(error)
            raise
def exist_oss_directory_path(directory_path, bucket_name="art-recommend", region="cn-hangzhou"):
    """Return True if at least one object exists under ``directory_path`` in OSS.

    OSS has no real directories. A "directory" is treated as existing when some
    object key starts with ``directory_path`` followed by ``/``.

    :param directory_path: object-key prefix, with or without a trailing ``/``.
    :param bucket_name: bucket to search (defaults to the script's bucket).
    :param region: OSS region of the bucket.
    :return: True if the prefix holds at least one object, else False.
    """
    prefix = directory_path if directory_path.endswith('/') else directory_path + '/'

    oss_config = oss.config.load_default()
    oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
        access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
    )
    oss_config.region = region
    client = oss.Client(oss_config)

    # A single key is enough to prove the prefix is non-empty.
    result = client.list_objects(oss.ListObjectsRequest(
        bucket=bucket_name,
        prefix=prefix,
        max_keys=1,
    ))
    # ``contents`` may be None rather than an empty list when nothing matches.
    return bool(result.contents)
def update_trained_cids_pointer(model_name=None, dt_version=None,
                                bucket_name="art-recommend",
                                object_key="fengzhoutian/pai_model_trained_cids/model_version.json"):
    """Overwrite the OSS pointer file recording which trained-CID snapshot is live.

    The object ``oss://<bucket_name>/<object_key>`` is written as UTF-8 JSON with
    the keys ``modelName`` (``"model_name=<model_name>"``), ``dtVersion``
    (``"dt_version=<dt_version>"``) and ``timestamp`` (current Unix seconds).

    :param model_name: model name; required and non-empty.
    :param dt_version: date/version string; required and non-empty.
    :param bucket_name: destination bucket.
    :param object_key: destination object key.
    :return: the ``put_object`` result (callers read ``status_code``).
    :raises ValueError: if either ``model_name`` or ``dt_version`` is missing/empty.
    """
    # Both values are required. There is no fallback that derives them from a
    # workflow, despite what an earlier comment here claimed.
    if not (model_name and dt_version):
        raise ValueError("model_name 和 dt_version 必须同时提供")

    model_version = {
        'modelName': f"model_name={model_name}",
        'dtVersion': f"dt_version={dt_version}",
        'timestamp': int(time.time()),
    }

    oss_config = oss.config.load_default()
    oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
        access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
    )
    oss_config.region = "cn-hangzhou"
    client = oss.Client(oss_config)

    return client.put_object(oss.PutObjectRequest(
        bucket=bucket_name,
        key=object_key,
        body=json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8'),
    ))
def _wait_until_running(timeout, poll_interval):
    """Poll EAS until the service reports ``Running``; return False on timeout.

    :param timeout: seconds to wait in total, or None to wait forever.
    :param poll_interval: seconds to sleep before each status check.
    :return: True once ``Status == 'Running'``, False if ``timeout`` elapsed first.
    """
    # Monotonic clock so wall-clock adjustments cannot distort the deadline.
    deadline = None if timeout is None else time.monotonic() + timeout
    while True:
        time.sleep(poll_interval)
        # NOTE(review): the first poll could still see the pre-update 'Running'
        # status if EAS has not yet switched to an updating state. Confirm
        # that poll_interval is long enough for that transition.
        if EASClient.get_online_model_detail()['Status'] == 'Running':
            return True
        if deadline is not None and time.monotonic() >= deadline:
            return False


def update_online_dnn_version(model_name: str, dt: str, timeout=3600.0, poll_interval=15.0):
    """Point the EAS service at a new exported model, then update the CID pointer.

    Steps:
      1. Check that both the exported model directory and the trained-CID
         directory exist in OSS.
      2. Merge the new ``model_path`` into the EAS service config.
      3. Wait until the service reports ``Running`` again.
      4. Rewrite the trained-CID pointer file.

    Each failure is reported with ``print``, and later steps are skipped.

    :param model_name: model name used in both OSS paths.
    :param dt: date/version directory name.
    :param timeout: max seconds to wait for ``Running``; None waits forever.
    :param poll_interval: seconds between status polls.
    :return: None.
    """
    if not (model_name and dt):
        print('请输入 model_name 和 dt 参数')
        return

    model_path = f'fengzhoutian/pai_model_exports/ad_rank_dnn_{model_name}/{dt}'
    cids_path = f'fengzhoutian/pai_model_trained_cids/model_name={model_name}/dt_version={dt}'

    model_exist = exist_oss_directory_path(model_path)
    cids_exist = exist_oss_directory_path(cids_path)
    if not (model_exist and cids_exist):
        # Report which of the required directories is missing.
        missing_files = []
        if not model_exist:
            missing_files.append("模型文件")
        if not cids_exist:
            missing_files.append("CID文件")
        print(f"以下文件不存在: {', '.join(missing_files)}")
        return

    # Serialize once, and log exactly what is sent to EAS.
    model_path_config = json.dumps({"model_path": 'oss://art-recommend/' + model_path})
    print(model_path_config)
    online_model_update_res = EASClient.update_online_model_path(model_path_config)
    if online_model_update_res.status_code != 200:
        print('在线模型更新失败,不再更新cid文件')
        return

    # The CID pointer must only move once the new model is actually serving.
    if not _wait_until_running(timeout, poll_interval):
        print(f'等待在线服务进入 Running 状态超时({timeout}s),不再更新cid文件')
        return

    cid_update_res = update_trained_cids_pointer(model_name, dt)
    if cid_update_res.status_code == 200:
        print('模型更新成功,cid配置更新成功')
    else:
        print("cid文件更新失败,请手动配置更新")
if __name__ == '__main__':
    # CLI usage: python <script> <model_name> <dt>; extra arguments are ignored.
    cli_args = sys.argv[1:]
    if len(cli_args) >= 2:
        update_online_dnn_version(cli_args[0], cli_args[1])
    else:
        print('请输入 model_name 和 dt 参数')
        sys.exit(1)
|