Selaa lähdekoodia

增加更新dnn在线配置脚本

xueyiming 1 päivä sitten
vanhempi
commit
8b08930567
1 muutettua tiedostoa jossa 148 lisäystä ja 0 poistoa
  1. 148 0
      ad/update_online_dnn_version.py

+ 148 - 0
ad/update_online_dnn_version.py

@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+# This file is auto-generated, don't edit it. Thanks.
+import json
+import os
+import sys
+import time
+
+from typing import List
+
+from alibabacloud_eas20210701.client import Client as eas20210701Client
+from alibabacloud_credentials.client import Client as CredentialClient
+from alibabacloud_tea_openapi import models as open_api_models
+from alibabacloud_eas20210701 import models as eas_20210701_models
+from alibabacloud_tea_util import models as util_models
+from alibabacloud_tea_util.client import Client as UtilClient
+import alibabacloud_oss_v2 as oss
+
+SERVICE_NAME = "ad_rank_dnn_v11_easyrec_test"
+ACCESS_KEY_ID = "LTAI5tFGqgC8f3mh1fRCrAEy"
+ACCESS_KEY_SECRET = "XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
+
+
+class EASClient:
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def create_client() -> eas20210701Client:
+        """
+        使用凭据初始化账号Client
+        @return: Client
+        @throws Exception
+        """
+        config = open_api_models.Config(
+            access_key_id=ACCESS_KEY_ID,
+            access_key_secret=ACCESS_KEY_SECRET
+        )
+        # Endpoint 请参考 https://api.aliyun.com/product/eas
+        config.endpoint = f'pai-eas.cn-hangzhou.aliyuncs.com'
+        return eas20210701Client(config)
+
+    @staticmethod
+    def update_online_model_path(model_path_config: str):
+        client = EASClient.create_client()
+        update_service_request = eas_20210701_models.UpdateServiceRequest(
+            update_type='merge',
+            body=model_path_config
+        )
+        runtime = util_models.RuntimeOptions()
+        headers = {}
+        try:
+            # 复制代码运行请自行打印 API 的返回值
+            res = client.update_service_with_options('cn-hangzhou', SERVICE_NAME, update_service_request,
+                                                     headers, runtime)
+            return res
+        except Exception as error:
+            # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
+            # 错误 message
+            print(error.message)
+            # 诊断地址
+            print(error.data.get("Recommend"))
+            UtilClient.assert_as_string(error.message)
+
+
+def exist_oss_directory_path(directory_path):
+    if not directory_path.endswith('/'):
+        directory_path += '/'
+    oss_config = oss.config.load_default()
+    oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
+        access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
+    )
+    oss_config.region = "cn-hangzhou"
+    client = oss.Client(oss_config)
+    bucket_name = "art-recommend"
+    # 列举指定前缀下的对象
+    options = oss.ListObjectsRequest(bucket_name)
+    options.prefix = directory_path  # 设置前缀为目录路径
+    options.max_keys = 1  # 只需要判断是否存在文件,所以只需要获取一个对象
+    result = client.list_objects(options)
+    if result.contents is None:
+        return False
+    return len(result.contents) > 0
+
+
+def update_trained_cids_pointer(model_name=None, dt_version=None):
+    # 如均为空,则从工作流中获取
+    if not (model_name and dt_version):
+        # 不允许其中一个为空
+        raise Exception("model_name 和 dt_version 必须同时提供")
+    model_version = {}
+    model_version['modelName'] = f"model_name={model_name}"
+    model_version['dtVersion'] = f"dt_version={dt_version}"
+    model_version['timestamp'] = int(time.time())
+    bucket_name = "art-recommend"
+    object_key = "fengzhoutian/pai_model_trained_cids/model_version.json"
+
+    oss_config = oss.config.load_default()
+    oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
+        access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
+    )
+    oss_config.region = "cn-hangzhou"
+    client = oss.Client(oss_config)
+    res = client.put_object(oss.PutObjectRequest(
+        bucket=bucket_name,
+        key=object_key,
+        body=json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8')
+    ))
+    return res
+
+
+def update_online_dnn_version(model_name: str, dt: str):
+    if not (model_name and dt):
+        print('请输入 model_name 和 dt 参数')
+        return
+    model_path = f'fengzhoutian/pai_model_exports/ad_rank_dnn_{model_name}/{dt}'
+    cids_path = f'fengzhoutian/pai_model_trained_cids/model_name={model_name}/dt_version={dt}'
+    model_exist = exist_oss_directory_path(model_path)
+    cids_exist = exist_oss_directory_path(cids_path)
+    if model_exist and cids_exist:
+        oss_model_path = 'oss://art-recommend/' + model_path
+        model_path_config = {
+            "model_path": oss_model_path
+        }
+        print(json.dumps(model_path_config))
+        online_model_update_res = EASClient.update_online_model_path(json.dumps(model_path_config))
+        if online_model_update_res.status_code == 200:
+            cid_update_res = update_trained_cids_pointer(model_name, dt)
+            if cid_update_res.status_code == 200:
+                print('模型更新成功,cid配置更新成功')
+            else:
+                print("cid文件更新失败,请手动配置更新")
+        else:
+            print('在线模型更新失败,不再更新cid文件')
+    else:
+        # 输出哪个文件不存在
+        missing_files = []
+        if not model_exist:
+            missing_files.append("模型文件")
+        if not cids_exist:
+            missing_files.append("CID文件")
+        print(f"以下文件不存在: {', '.join(missing_files)}")
+
+
+if __name__ == '__main__':
+    if len(sys.argv) < 3:
+        print('请输入 model_name 和 dt 参数')
+        sys.exit(1)
+    update_online_dnn_version(sys.argv[1], sys.argv[2])