# -*- coding: utf-8 -*-
# update_online_dnn_version.py
# Originally derived from auto-generated Alibaba Cloud SDK sample code; this copy
# contains project-specific logic and is maintained by hand.
import json
import os
import sys
import time

import alibabacloud_oss_v2 as oss
from alibabacloud_eas20210701 import models as eas_20210701_models
from alibabacloud_eas20210701.client import Client as eas20210701Client
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tea_util import models as util_models
from alibabacloud_tea_util.client import Client as UtilClient
  12. SERVICE_NAME = "ad_rank_dnn_v11_easyrec"
  13. ACCESS_KEY_ID = "LTAI5tFGqgC8f3mh1fRCrAEy"
  14. ACCESS_KEY_SECRET = "XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
  15. class EASClient:
  16. def __init__(self):
  17. pass
  18. @staticmethod
  19. def create_client() -> eas20210701Client:
  20. """
  21. 使用凭据初始化账号Client
  22. @return: Client
  23. @throws Exception
  24. """
  25. config = open_api_models.Config(
  26. access_key_id=ACCESS_KEY_ID,
  27. access_key_secret=ACCESS_KEY_SECRET
  28. )
  29. # Endpoint 请参考 https://api.aliyun.com/product/eas
  30. config.endpoint = f'pai-eas.cn-hangzhou.aliyuncs.com'
  31. return eas20210701Client(config)
  32. @staticmethod
  33. def update_online_model_path(model_path_config: str):
  34. client = EASClient.create_client()
  35. update_service_request = eas_20210701_models.UpdateServiceRequest(
  36. update_type='merge',
  37. body=model_path_config
  38. )
  39. runtime = util_models.RuntimeOptions()
  40. headers = {}
  41. try:
  42. # 复制代码运行请自行打印 API 的返回值
  43. res = client.update_service_with_options('cn-hangzhou', SERVICE_NAME, update_service_request,
  44. headers, runtime)
  45. return res
  46. except Exception as error:
  47. # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
  48. # 错误 message
  49. print(error.message)
  50. # 诊断地址
  51. print(error.data.get("Recommend"))
  52. UtilClient.assert_as_string(error.message)
  53. @staticmethod
  54. def get_online_model_detail():
  55. client = EASClient.create_client()
  56. runtime = util_models.RuntimeOptions()
  57. headers = {}
  58. try:
  59. # 复制代码运行请自行打印 API 的返回值
  60. res = client.describe_service_with_options('cn-hangzhou', SERVICE_NAME, headers, runtime)
  61. return res.body.to_map()
  62. except Exception as error:
  63. # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
  64. # 错误 message
  65. print(error.message)
  66. # 诊断地址
  67. print(error.data.get("Recommend"))
  68. UtilClient.assert_as_string(error.message)
  69. def exist_oss_directory_path(directory_path):
  70. if not directory_path.endswith('/'):
  71. directory_path += '/'
  72. oss_config = oss.config.load_default()
  73. oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
  74. access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
  75. )
  76. oss_config.region = "cn-hangzhou"
  77. client = oss.Client(oss_config)
  78. bucket_name = "art-recommend"
  79. # 列举指定前缀下的对象
  80. options = oss.ListObjectsRequest(bucket_name)
  81. options.prefix = directory_path # 设置前缀为目录路径
  82. options.max_keys = 1 # 只需要判断是否存在文件,所以只需要获取一个对象
  83. result = client.list_objects(options)
  84. if result.contents is None:
  85. return False
  86. return len(result.contents) > 0
  87. def update_trained_cids_pointer(model_name=None, dt_version=None):
  88. # 如均为空,则从工作流中获取
  89. if not (model_name and dt_version):
  90. # 不允许其中一个为空
  91. raise Exception("model_name 和 dt_version 必须同时提供")
  92. model_version = {}
  93. model_version['modelName'] = f"model_name={model_name}"
  94. model_version['dtVersion'] = f"dt_version={dt_version}"
  95. model_version['timestamp'] = int(time.time())
  96. bucket_name = "art-recommend"
  97. object_key = "fengzhoutian/pai_model_trained_cids/model_version.json"
  98. oss_config = oss.config.load_default()
  99. oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
  100. access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
  101. )
  102. oss_config.region = "cn-hangzhou"
  103. client = oss.Client(oss_config)
  104. res = client.put_object(oss.PutObjectRequest(
  105. bucket=bucket_name,
  106. key=object_key,
  107. body=json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8')
  108. ))
  109. return res
  110. def update_online_dnn_version(model_name: str, dt: str):
  111. if not (model_name and dt):
  112. print('请输入 model_name 和 dt 参数')
  113. return
  114. model_path = f'fengzhoutian/pai_model_exports/ad_rank_dnn_{model_name}/{dt}'
  115. cids_path = f'fengzhoutian/pai_model_trained_cids/model_name={model_name}/dt_version={dt}'
  116. model_exist = exist_oss_directory_path(model_path)
  117. cids_exist = exist_oss_directory_path(cids_path)
  118. if model_exist and cids_exist:
  119. oss_model_path = 'oss://art-recommend/' + model_path
  120. model_path_config = {
  121. "model_path": oss_model_path
  122. }
  123. print(json.dumps(model_path_config))
  124. online_model_update_res = EASClient.update_online_model_path(json.dumps(model_path_config))
  125. if online_model_update_res.status_code == 200:
  126. time.sleep(60)
  127. status = EASClient.get_online_model_detail()['Status']
  128. while status != 'Running':
  129. time.sleep(60)
  130. status = EASClient.get_online_model_detail()['Status']
  131. cid_update_res = update_trained_cids_pointer(model_name, dt)
  132. if cid_update_res.status_code == 200:
  133. print('模型更新成功,cid配置更新成功')
  134. else:
  135. print("cid文件更新失败,请手动配置更新")
  136. else:
  137. print('在线模型更新失败,不再更新cid文件')
  138. else:
  139. # 输出哪个文件不存在
  140. missing_files = []
  141. if not model_exist:
  142. missing_files.append("模型文件")
  143. if not cids_exist:
  144. missing_files.append("CID文件")
  145. print(f"以下文件不存在: {', '.join(missing_files)}")
  146. if __name__ == '__main__':
  147. if len(sys.argv) < 3:
  148. print('请输入 model_name 和 dt 参数')
  149. sys.exit(1)
  150. update_online_dnn_version(sys.argv[1], sys.argv[2])