update_online_dnn_version.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # -*- coding: utf-8 -*-
  2. # This file is auto-generated, don't edit it. Thanks.
import json
import os
import sys
import time

import alibabacloud_oss_v2 as oss
from alibabacloud_eas20210701 import models as eas_20210701_models
from alibabacloud_eas20210701.client import Client as eas20210701Client
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_tea_util import models as util_models
from alibabacloud_tea_util.client import Client as UtilClient
  12. SERVICE_NAME = "ad_rank_dnn_v11_easyrec"
  13. ACCESS_KEY_ID = "LTAI5tFGqgC8f3mh1fRCrAEy"
  14. ACCESS_KEY_SECRET = "XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
  15. class EASClient:
  16. def __init__(self):
  17. pass
  18. @staticmethod
  19. def create_client() -> eas20210701Client:
  20. """
  21. 使用凭据初始化账号Client
  22. @return: Client
  23. @throws Exception
  24. """
  25. config = open_api_models.Config(
  26. access_key_id=ACCESS_KEY_ID,
  27. access_key_secret=ACCESS_KEY_SECRET
  28. )
  29. # Endpoint 请参考 https://api.aliyun.com/product/eas
  30. config.endpoint = f'pai-eas.cn-hangzhou.aliyuncs.com'
  31. return eas20210701Client(config)
  32. @staticmethod
  33. def update_online_model_path(model_path_config: str):
  34. client = EASClient.create_client()
  35. update_service_request = eas_20210701_models.UpdateServiceRequest(
  36. update_type='merge',
  37. body=model_path_config
  38. )
  39. runtime = util_models.RuntimeOptions()
  40. headers = {}
  41. try:
  42. # 复制代码运行请自行打印 API 的返回值
  43. res = client.update_service_with_options('cn-hangzhou', SERVICE_NAME, update_service_request,
  44. headers, runtime)
  45. return res
  46. except Exception as error:
  47. print(error.message)
  48. # 诊断地址
  49. print(error.data.get("Recommend"))
  50. raise error
  51. @staticmethod
  52. def get_online_model_detail():
  53. client = EASClient.create_client()
  54. runtime = util_models.RuntimeOptions()
  55. headers = {}
  56. try:
  57. res = client.describe_service_with_options('cn-hangzhou', SERVICE_NAME, headers, runtime)
  58. return res.body.to_map()
  59. except Exception as error:
  60. print(error.message)
  61. # 诊断地址
  62. print(error.data.get("Recommend"))
  63. raise error
  64. def exist_oss_directory_path(directory_path):
  65. if not directory_path.endswith('/'):
  66. directory_path += '/'
  67. oss_config = oss.config.load_default()
  68. oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
  69. access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
  70. )
  71. oss_config.region = "cn-hangzhou"
  72. client = oss.Client(oss_config)
  73. bucket_name = "art-recommend"
  74. # 列举指定前缀下的对象
  75. options = oss.ListObjectsRequest(bucket_name)
  76. options.prefix = directory_path # 设置前缀为目录路径
  77. options.max_keys = 1 # 只需要判断是否存在文件,所以只需要获取一个对象
  78. result = client.list_objects(options)
  79. if result.contents is None:
  80. return False
  81. return len(result.contents) > 0
  82. def update_trained_cids_pointer(model_name=None, dt_version=None):
  83. # 如均为空,则从工作流中获取
  84. if not (model_name and dt_version):
  85. # 不允许其中一个为空
  86. raise Exception("model_name 和 dt_version 必须同时提供")
  87. model_version = {}
  88. model_version['modelName'] = f"model_name={model_name}"
  89. model_version['dtVersion'] = f"dt_version={dt_version}"
  90. model_version['timestamp'] = int(time.time())
  91. bucket_name = "art-recommend"
  92. object_key = "fengzhoutian/pai_model_trained_cids/model_version.json"
  93. oss_config = oss.config.load_default()
  94. oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
  95. access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
  96. )
  97. oss_config.region = "cn-hangzhou"
  98. client = oss.Client(oss_config)
  99. res = client.put_object(oss.PutObjectRequest(
  100. bucket=bucket_name,
  101. key=object_key,
  102. body=json.dumps(model_version, ensure_ascii=False, indent=4).encode('utf-8')
  103. ))
  104. return res
  105. def update_online_dnn_version(model_name: str, dt: str):
  106. if not (model_name and dt):
  107. print('请输入 model_name 和 dt 参数')
  108. return
  109. model_path = f'fengzhoutian/pai_model_exports/ad_rank_dnn_{model_name}/{dt}'
  110. cids_path = f'fengzhoutian/pai_model_trained_cids/model_name={model_name}/dt_version={dt}'
  111. model_exist = exist_oss_directory_path(model_path)
  112. cids_exist = exist_oss_directory_path(cids_path)
  113. if model_exist and cids_exist:
  114. oss_model_path = 'oss://art-recommend/' + model_path
  115. model_path_config = {
  116. "model_path": oss_model_path
  117. }
  118. print(json.dumps(model_path_config))
  119. online_model_update_res = EASClient.update_online_model_path(json.dumps(model_path_config))
  120. if online_model_update_res.status_code == 200:
  121. while True:
  122. time.sleep(15)
  123. status = EASClient.get_online_model_detail()['Status']
  124. if status == 'Running':
  125. break
  126. cid_update_res = update_trained_cids_pointer(model_name, dt)
  127. if cid_update_res.status_code == 200:
  128. print('模型更新成功,cid配置更新成功')
  129. else:
  130. print("cid文件更新失败,请手动配置更新")
  131. else:
  132. print('在线模型更新失败,不再更新cid文件')
  133. else:
  134. # 输出哪个文件不存在
  135. missing_files = []
  136. if not model_exist:
  137. missing_files.append("模型文件")
  138. if not cids_exist:
  139. missing_files.append("CID文件")
  140. print(f"以下文件不存在: {', '.join(missing_files)}")
  141. if __name__ == '__main__':
  142. if len(sys.argv) < 3:
  143. print('请输入 model_name 和 dt 参数')
  144. sys.exit(1)
  145. update_online_dnn_version(sys.argv[1], sys.argv[2])