update_online_dnn_version.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. # -*- coding: utf-8 -*-
  2. # This file is auto-generated, don't edit it. Thanks.
  3. import json
  4. import os
  5. import sys
  6. import time
  7. from typing import List
  8. from alibabacloud_eas20210701.client import Client as eas20210701Client
  9. from alibabacloud_credentials.client import Client as CredentialClient
  10. from alibabacloud_tea_openapi import models as open_api_models
  11. from alibabacloud_eas20210701 import models as eas_20210701_models
  12. from alibabacloud_tea_util import models as util_models
  13. from alibabacloud_tea_util.client import Client as UtilClient
  14. import alibabacloud_oss_v2 as oss
  15. SERVICE_NAME = "ad_rank_dnn_v11_easyrec_test"
  16. ACCESS_KEY_ID = "LTAI5tFGqgC8f3mh1fRCrAEy"
  17. ACCESS_KEY_SECRET = "XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
  18. class EASClient:
  19. def __init__(self):
  20. pass
  21. @staticmethod
  22. def create_client() -> eas20210701Client:
  23. """
  24. 使用凭据初始化账号Client
  25. @return: Client
  26. @throws Exception
  27. """
  28. config = open_api_models.Config(
  29. access_key_id=ACCESS_KEY_ID,
  30. access_key_secret=ACCESS_KEY_SECRET
  31. )
  32. # Endpoint 请参考 https://api.aliyun.com/product/eas
  33. config.endpoint = f'pai-eas.cn-hangzhou.aliyuncs.com'
  34. return eas20210701Client(config)
  35. @staticmethod
  36. def update_online_model_path(model_path_config: str):
  37. client = EASClient.create_client()
  38. update_service_request = eas_20210701_models.UpdateServiceRequest(
  39. update_type='merge',
  40. body=model_path_config
  41. )
  42. runtime = util_models.RuntimeOptions()
  43. headers = {}
  44. try:
  45. # 复制代码运行请自行打印 API 的返回值
  46. res = client.update_service_with_options('cn-hangzhou', SERVICE_NAME, update_service_request,
  47. headers, runtime)
  48. return res
  49. except Exception as error:
  50. # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
  51. # 错误 message
  52. print(error.message)
  53. # 诊断地址
  54. print(error.data.get("Recommend"))
  55. UtilClient.assert_as_string(error.message)
  56. def exist_oss_directory_path(directory_path):
  57. if not directory_path.endswith('/'):
  58. directory_path += '/'
  59. oss_config = oss.config.load_default()
  60. oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
  61. access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
  62. )
  63. oss_config.region = "cn-hangzhou"
  64. client = oss.Client(oss_config)
  65. bucket_name = "art-recommend"
  66. # 列举指定前缀下的对象
  67. options = oss.ListObjectsRequest(bucket_name)
  68. options.prefix = directory_path # 设置前缀为目录路径
  69. options.max_keys = 1 # 只需要判断是否存在文件,所以只需要获取一个对象
  70. result = client.list_objects(options)
  71. if result.contents is None:
  72. return False
  73. return len(result.contents) > 0
  74. def update_trained_cids_pointer(model_name=None, dt_version=None):
  75. # 如均为空,则从工作流中获取
  76. if not (model_name and dt_version):
  77. # 不允许其中一个为空
  78. raise Exception("model_name 和 dt_version 必须同时提供")
  79. model_version = {}
  80. model_version['modelName'] = f"model_name={model_name}"
  81. model_version['dtVersion'] = f"dt_version={dt_version}"
  82. model_version['timestamp'] = int(time.time())
  83. bucket_name = "art-recommend"
  84. object_key = "fengzhoutian/pai_model_trained_cids/model_version.json"
  85. oss_config = oss.config.load_default()
  86. oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
  87. access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
  88. )
  89. oss_config.region = "cn-hangzhou"
  90. client = oss.Client(oss_config)
  91. res = client.put_object(oss.PutObjectRequest(
  92. bucket=bucket_name,
  93. key=object_key,
  94. body=json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8')
  95. ))
  96. return res
  97. def update_online_dnn_version(model_name: str, dt: str):
  98. if not (model_name and dt):
  99. print('请输入 model_name 和 dt 参数')
  100. return
  101. model_path = f'fengzhoutian/pai_model_exports/ad_rank_dnn_{model_name}/{dt}'
  102. cids_path = f'fengzhoutian/pai_model_trained_cids/model_name={model_name}/dt_version={dt}'
  103. model_exist = exist_oss_directory_path(model_path)
  104. cids_exist = exist_oss_directory_path(cids_path)
  105. if model_exist and cids_exist:
  106. oss_model_path = 'oss://art-recommend/' + model_path
  107. model_path_config = {
  108. "model_path": oss_model_path
  109. }
  110. print(json.dumps(model_path_config))
  111. online_model_update_res = EASClient.update_online_model_path(json.dumps(model_path_config))
  112. if online_model_update_res.status_code == 200:
  113. cid_update_res = update_trained_cids_pointer(model_name, dt)
  114. if cid_update_res.status_code == 200:
  115. print('模型更新成功,cid配置更新成功')
  116. else:
  117. print("cid文件更新失败,请手动配置更新")
  118. else:
  119. print('在线模型更新失败,不再更新cid文件')
  120. else:
  121. # 输出哪个文件不存在
  122. missing_files = []
  123. if not model_exist:
  124. missing_files.append("模型文件")
  125. if not cids_exist:
  126. missing_files.append("CID文件")
  127. print(f"以下文件不存在: {', '.join(missing_files)}")
  128. if __name__ == '__main__':
  129. if len(sys.argv) < 3:
  130. print('请输入 model_name 和 dt 参数')
  131. sys.exit(1)
  132. update_online_dnn_version(sys.argv[1], sys.argv[2])