|
@@ -0,0 +1,148 @@
|
|
|
+# -*- coding: utf-8 -*-
|
|
|
+# This file is auto-generated, don't edit it. Thanks.
|
|
|
+import json
|
|
|
+import os
|
|
|
+import sys
|
|
|
+import time
|
|
|
+
|
|
|
+from typing import List
|
|
|
+
|
|
|
+from alibabacloud_eas20210701.client import Client as eas20210701Client
|
|
|
+from alibabacloud_credentials.client import Client as CredentialClient
|
|
|
+from alibabacloud_tea_openapi import models as open_api_models
|
|
|
+from alibabacloud_eas20210701 import models as eas_20210701_models
|
|
|
+from alibabacloud_tea_util import models as util_models
|
|
|
+from alibabacloud_tea_util.client import Client as UtilClient
|
|
|
+import alibabacloud_oss_v2 as oss
|
|
|
+
|
|
|
+SERVICE_NAME = "ad_rank_dnn_v11_easyrec_test"
|
|
|
+ACCESS_KEY_ID = "LTAI5tFGqgC8f3mh1fRCrAEy"
|
|
|
+ACCESS_KEY_SECRET = "XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
|
|
|
+
|
|
|
+
|
|
|
+class EASClient:
|
|
|
+ def __init__(self):
|
|
|
+ pass
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def create_client() -> eas20210701Client:
|
|
|
+ """
|
|
|
+ 使用凭据初始化账号Client
|
|
|
+ @return: Client
|
|
|
+ @throws Exception
|
|
|
+ """
|
|
|
+ config = open_api_models.Config(
|
|
|
+ access_key_id=ACCESS_KEY_ID,
|
|
|
+ access_key_secret=ACCESS_KEY_SECRET
|
|
|
+ )
|
|
|
+ # Endpoint 请参考 https://api.aliyun.com/product/eas
|
|
|
+ config.endpoint = f'pai-eas.cn-hangzhou.aliyuncs.com'
|
|
|
+ return eas20210701Client(config)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def update_online_model_path(model_path_config: str):
|
|
|
+ client = EASClient.create_client()
|
|
|
+ update_service_request = eas_20210701_models.UpdateServiceRequest(
|
|
|
+ update_type='merge',
|
|
|
+ body=model_path_config
|
|
|
+ )
|
|
|
+ runtime = util_models.RuntimeOptions()
|
|
|
+ headers = {}
|
|
|
+ try:
|
|
|
+ # 复制代码运行请自行打印 API 的返回值
|
|
|
+ res = client.update_service_with_options('cn-hangzhou', SERVICE_NAME, update_service_request,
|
|
|
+ headers, runtime)
|
|
|
+ return res
|
|
|
+ except Exception as error:
|
|
|
+ # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
|
|
|
+ # 错误 message
|
|
|
+ print(error.message)
|
|
|
+ # 诊断地址
|
|
|
+ print(error.data.get("Recommend"))
|
|
|
+ UtilClient.assert_as_string(error.message)
|
|
|
+
|
|
|
+
|
|
|
+def exist_oss_directory_path(directory_path):
|
|
|
+ if not directory_path.endswith('/'):
|
|
|
+ directory_path += '/'
|
|
|
+ oss_config = oss.config.load_default()
|
|
|
+ oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
|
|
|
+ access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
|
|
|
+ )
|
|
|
+ oss_config.region = "cn-hangzhou"
|
|
|
+ client = oss.Client(oss_config)
|
|
|
+ bucket_name = "art-recommend"
|
|
|
+ # 列举指定前缀下的对象
|
|
|
+ options = oss.ListObjectsRequest(bucket_name)
|
|
|
+ options.prefix = directory_path # 设置前缀为目录路径
|
|
|
+ options.max_keys = 1 # 只需要判断是否存在文件,所以只需要获取一个对象
|
|
|
+ result = client.list_objects(options)
|
|
|
+ if result.contents is None:
|
|
|
+ return False
|
|
|
+ return len(result.contents) > 0
|
|
|
+
|
|
|
+
|
|
|
+def update_trained_cids_pointer(model_name=None, dt_version=None):
|
|
|
+ # 如均为空,则从工作流中获取
|
|
|
+ if not (model_name and dt_version):
|
|
|
+ # 不允许其中一个为空
|
|
|
+ raise Exception("model_name 和 dt_version 必须同时提供")
|
|
|
+ model_version = {}
|
|
|
+ model_version['modelName'] = f"model_name={model_name}"
|
|
|
+ model_version['dtVersion'] = f"dt_version={dt_version}"
|
|
|
+ model_version['timestamp'] = int(time.time())
|
|
|
+ bucket_name = "art-recommend"
|
|
|
+ object_key = "fengzhoutian/pai_model_trained_cids/model_version.json"
|
|
|
+
|
|
|
+ oss_config = oss.config.load_default()
|
|
|
+ oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
|
|
|
+ access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
|
|
|
+ )
|
|
|
+ oss_config.region = "cn-hangzhou"
|
|
|
+ client = oss.Client(oss_config)
|
|
|
+ res = client.put_object(oss.PutObjectRequest(
|
|
|
+ bucket=bucket_name,
|
|
|
+ key=object_key,
|
|
|
+ body=json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8')
|
|
|
+ ))
|
|
|
+ return res
|
|
|
+
|
|
|
+
|
|
|
+def update_online_dnn_version(model_name: str, dt: str):
|
|
|
+ if not (model_name and dt):
|
|
|
+ print('请输入 model_name 和 dt 参数')
|
|
|
+ return
|
|
|
+ model_path = f'fengzhoutian/pai_model_exports/ad_rank_dnn_{model_name}/{dt}'
|
|
|
+ cids_path = f'fengzhoutian/pai_model_trained_cids/model_name={model_name}/dt_version={dt}'
|
|
|
+ model_exist = exist_oss_directory_path(model_path)
|
|
|
+ cids_exist = exist_oss_directory_path(cids_path)
|
|
|
+ if model_exist and cids_exist:
|
|
|
+ oss_model_path = 'oss://art-recommend/' + model_path
|
|
|
+ model_path_config = {
|
|
|
+ "model_path": oss_model_path
|
|
|
+ }
|
|
|
+ print(json.dumps(model_path_config))
|
|
|
+ online_model_update_res = EASClient.update_online_model_path(json.dumps(model_path_config))
|
|
|
+ if online_model_update_res.status_code == 200:
|
|
|
+ cid_update_res = update_trained_cids_pointer(model_name, dt)
|
|
|
+ if cid_update_res.status_code == 200:
|
|
|
+ print('模型更新成功,cid配置更新成功')
|
|
|
+ else:
|
|
|
+ print("cid文件更新失败,请手动配置更新")
|
|
|
+ else:
|
|
|
+ print('在线模型更新失败,不再更新cid文件')
|
|
|
+ else:
|
|
|
+ # 输出哪个文件不存在
|
|
|
+ missing_files = []
|
|
|
+ if not model_exist:
|
|
|
+ missing_files.append("模型文件")
|
|
|
+ if not cids_exist:
|
|
|
+ missing_files.append("CID文件")
|
|
|
+ print(f"以下文件不存在: {', '.join(missing_files)}")
|
|
|
+
|
|
|
+
|
|
|
+if __name__ == '__main__':
|
|
|
+ if len(sys.argv) < 3:
|
|
|
+ print('请输入 model_name 和 dt 参数')
|
|
|
+ sys.exit(1)
|
|
|
+ update_online_dnn_version(sys.argv[1], sys.argv[2])
|