Bladeren bron

Update pai_flow_operator

StrayWarrior 1 dag geleden
bovenliggende
commit
20cbab30b5
1 gewijzigde bestanden met toevoegingen van 70 en 23 verwijderingen
  1. 70 23
      ad/pai_flow_operator.py

+ 70 - 23
ad/pai_flow_operator.py

@@ -17,6 +17,7 @@ from alibabacloud_paiflow20210202.client import Client as PAIFlow20210202Client
+import os
 from datetime import datetime, timedelta
 from odps import ODPS
 from ad_monitor_util import _monitor
+import alibabacloud_oss_v2 as oss
 
 target_names = {
     '样本shuffle',
@@ -31,7 +32,9 @@ target_names = {
     '预测结果对比'
 }
 
-experiment_id = "draft-wqgkag89sbh9v1zvut"
+EXPERIMENT_ID = "draft-wqgkag89sbh9v1zvut"
+ACCESS_KEY_ID = "LTAI5tFGqgC8f3mh1fRCrAEy"
+ACCESS_KEY_SECRET = "XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
 
 MAX_RETRIES = 3
 
@@ -56,8 +59,8 @@ def retry(func):
 
 def get_odps_instance(project):
     odps = ODPS(
-        access_id='LTAIWYUujJAm7CbH',
-        secret_access_key='RfSjdiWwED1sGFlsjXv0DlfTnZTG1P',
+        access_id=ACCESS_KEY_ID,
+        secret_access_key=ACCESS_KEY_SECRET,
         project=project,
         endpoint='http://service.cn.maxcompute.aliyun.com/api',
     )
@@ -207,8 +210,8 @@ class PAIClient:
         # 工程代码泄露可能会导致 AccessKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考。
         # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
         config = open_api_models.Config(
-            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
-            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
+            access_key_id=ACCESS_KEY_ID,
+            access_key_secret=ACCESS_KEY_SECRET
         )
         # Endpoint 请参考 https://api.aliyun.com/product/PaiStudio
         config.endpoint = f'pai.cn-hangzhou.aliyuncs.com'
@@ -224,8 +227,8 @@ class PAIClient:
         # 工程代码泄露可能会导致 AccessKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考。
         # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
         config = open_api_models.Config(
-            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
-            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
+            access_key_id=ACCESS_KEY_ID,
+            access_key_secret=ACCESS_KEY_SECRET
         )
         # Endpoint 请参考 https://api.aliyun.com/product/PaiStudio
         config.endpoint = f'pai-eas.cn-hangzhou.aliyuncs.com'
@@ -242,9 +245,9 @@ class PAIClient:
         # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
         config = open_api_models.Config(
             # 必填,请确保代码运行环境设置了环境变量 ALIBABA_CLOUD_ACCESS_KEY_ID。,
-            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
+            access_key_id=ACCESS_KEY_ID,
             # 必填,请确保代码运行环境设置了环境变量 ALIBABA_CLOUD_ACCESS_KEY_SECRET。,
-            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
+            access_key_secret=ACCESS_KEY_SECRET
         )
         # Endpoint 请参考 https://api.aliyun.com/product/PAIFlow
         config.endpoint = f'paiflow.cn-hangzhou.aliyuncs.com'
@@ -307,7 +310,7 @@ class PAIClient:
     def create_job(experiment_id: str, node_id: str, execute_type: str):
         client = PAIClient.create_client()
         create_job_request = pai_studio_20210202_models.CreateJobRequest()
-        create_job_request.experiment_id = experiment_id
+        create_job_request.EXPERIMENT_ID = experiment_id
         create_job_request.node_id = node_id
         create_job_request.execute_type = execute_type
         runtime = util_models.RuntimeOptions()
@@ -399,7 +402,7 @@ def get_online_model_config(service_name: str):
 
 
 def update_shuffle_flow(table):
-    draft = PAIClient.get_work_flow_draft(experiment_id)
+    draft = PAIClient.get_work_flow_draft(EXPERIMENT_ID)
     print(json.dumps(draft, ensure_ascii=False))
     content = draft['Content']
     version = draft['Version']
@@ -417,11 +420,11 @@ def update_shuffle_flow(table):
                         print("error")
                     property['value'] = new_value
     new_content = json.dumps(content_json, ensure_ascii=False)
-    PAIClient.update_experiment_content(experiment_id, new_content, version)
+    PAIClient.update_experiment_content(EXPERIMENT_ID, new_content, version)
 
 
 def update_shuffle_flow_1():
-    draft = PAIClient.get_work_flow_draft(experiment_id)
+    draft = PAIClient.get_work_flow_draft(EXPERIMENT_ID)
     print(json.dumps(draft, ensure_ascii=False))
     content = draft['Content']
     version = draft['Version']
@@ -440,7 +443,7 @@ def update_shuffle_flow_1():
                         print("error")
                     property['value'] = new_value
     new_content = json.dumps(content_json, ensure_ascii=False)
-    PAIClient.update_experiment_content(experiment_id, new_content, version)
+    PAIClient.update_experiment_content(EXPERIMENT_ID, new_content, version)
 
 
 def wait_job_end(job_id: str, check_interval=300):
@@ -459,7 +462,7 @@ def wait_job_end(job_id: str, check_interval=300):
 
 
 def get_node_dict():
-    draft = PAIClient.get_work_flow_draft(experiment_id)
+    draft = PAIClient.get_work_flow_draft(EXPERIMENT_ID)
     content = draft['Content']
     content_json = json.loads(content)
     nodes = content_json.get('nodes')
@@ -474,7 +477,7 @@ def get_node_dict():
 
 def get_job_dict():
     job_dict = {}
-    jobs_list = PAIClient.get_jobs_list(experiment_id)
+    jobs_list = PAIClient.get_jobs_list(EXPERIMENT_ID)
     for job in jobs_list['Jobs']:
         # 解析时间字符串为 datetime 对象
         if not compare_timestamp_with_today_start(job['GmtCreateTime']):
@@ -494,7 +497,7 @@ def get_job_dict():
 def update_online_flow():
     try:
         online_model_config = get_online_model_config('ad_rank_dnn_v11_easyrec')
-        draft = PAIClient.get_work_flow_draft(experiment_id)
+        draft = PAIClient.get_work_flow_draft(EXPERIMENT_ID)
         print(json.dumps(draft, ensure_ascii=False))
         content = draft['Content']
         version = draft['Version']
@@ -531,7 +534,7 @@ def update_online_flow():
             except KeyError:
                 raise Exception("在处理节点属性时,字典中缺少必要的键")
         new_content = json.dumps(content_json, ensure_ascii=False)
-        PAIClient.update_experiment_content(experiment_id, new_content, version)
+        PAIClient.update_experiment_content(EXPERIMENT_ID, new_content, version)
         return True
     except json.JSONDecodeError:
         raise Exception("JSON 解析错误,可能是草稿内容格式不正确")
@@ -544,7 +547,7 @@ def shuffle_table():
         node_dict = get_node_dict()
         train_node_id = node_dict['样本shuffle']
         execute_type = 'EXECUTE_FROM_HERE'
-        validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+        validate_res = PAIClient.create_job(EXPERIMENT_ID, train_node_id, execute_type)
         validate_job_id = validate_res['JobId']
         validate_job_detail = wait_job_end(validate_job_id, 10)
         if validate_job_detail['Status'] == 'Succeeded':
@@ -578,7 +581,7 @@ def shuffle_train_model():
                 node_dict = get_node_dict()
                 train_node_id = node_dict['模型训练-样本shufle']
                 execute_type = 'EXECUTE_ONE'
-                train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+                train_res = PAIClient.create_job(EXPERIMENT_ID, train_node_id, execute_type)
                 train_job_id = train_res['JobId']
                 train_job_detail = wait_job_end(train_job_id)
                 if train_job_detail['Status'] == 'Succeeded':
@@ -596,7 +599,7 @@ def export_model():
         node_dict = get_node_dict()
         export_node_id = node_dict['模型导出-2']
         execute_type = 'EXECUTE_ONE'
-        export_res = PAIClient.create_job(experiment_id, export_node_id, execute_type)
+        export_res = PAIClient.create_job(EXPERIMENT_ID, export_node_id, execute_type)
         export_job_id = export_res['JobId']
         export_job_detail = wait_job_end(export_job_id)
         if export_job_detail['Status'] == 'Succeeded':
@@ -613,7 +616,7 @@ def update_online_model():
         node_dict = get_node_dict()
         train_node_id = node_dict['更新EAS服务(Beta)-1']
         execute_type = 'EXECUTE_ONE'
-        train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+        train_res = PAIClient.create_job(EXPERIMENT_ID, train_node_id, execute_type)
         train_job_id = train_res['JobId']
         train_job_detail = wait_job_end(train_job_id)
         if train_job_detail['Status'] == 'Succeeded':
@@ -630,7 +633,7 @@ def get_validate_model_data():
         node_dict = get_node_dict()
         train_node_id = node_dict['虚拟起始节点']
         execute_type = 'EXECUTE_FROM_HERE'
-        validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+        validate_res = PAIClient.create_job(EXPERIMENT_ID, train_node_id, execute_type)
         validate_job_id = validate_res['JobId']
         validate_job_detail = wait_job_end(validate_job_id)
         if validate_job_detail['Status'] == 'Succeeded':
@@ -710,6 +713,50 @@ def validate_model_data_accuracy():
         raise Exception(error_message)
 
 
+def update_trained_cids_pointer(model_name=None, dt_version=None):
+    # 如均为空,则从工作流中获取
+    if not model_name and not dt_version:
+        draft = PAIClient.get_work_flow_draft(EXPERIMENT_ID)
+        content = draft['Content']
+        content_json = json.loads(content)
+        global_params = content_json.get('globalParams', [])
+        model_name = None
+        dt_version = None
+        for param in global_params:
+            if param.get('name') == 'model_name':
+                model_name = param.get('value')
+            if param.get('name') == 'bizdate':
+                dt_version = param.get('value')
+        if not model_name or not dt_version:
+            raise Exception("globalParams 中未找到 model_name 或 bizdate")
+    elif not (model_name and dt_version):
+        # 不允许其中一个为空
+        raise Exception("model_name 和 dt_version 必须同时提供")
+    oss_content = f"oss://art-recommend/fengzhoutian/pai_model_trained_cids/model_name={model_name}/dt_version={dt_version}"
+
+    bucket_name = "art-recommend"
+    object_key = "fengzhoutian/ad_engine_files/pai_model_trained_cids_pointer"
+
+    oss_config = oss.config.load_default()
+    oss_config.credentials_provider = oss.credentials.StaticCredentialsProvider(
+        access_key_id=ACCESS_KEY_ID, access_key_secret=ACCESS_KEY_SECRET
+    )
+    oss_config.region = "cn-hangzhou"
+    client = oss.Client(oss_config)
+    ret = client.put_object(oss.PutObjectRequest(
+        bucket=bucket_name,
+        key=object_key,
+        body=oss_content.encode('utf-8')
+    ))
+    print(f'status code: {ret.status_code},'
+          f' request id: {ret.request_id},'
+          f' content md5: {ret.content_md5},'
+          f' etag: {ret.etag},'
+          f' hash crc64: {ret.hash_crc64},'
+          f' version id: {ret.version_id},'
+          )
+
+
 if __name__ == '__main__':
     start_time = int(time.time())
     functions = [update_online_flow, shuffle_table, shuffle_train_model, export_model, get_validate_model_data]