@@ -0,0 +1,619 @@
+# -*- coding: utf-8 -*-
+import os
+import re
+import sys
+
+import time
+import json
+import pandas as pd
+from alibabacloud_paistudio20210202.client import Client as PaiStudio20210202Client
+from alibabacloud_tea_openapi import models as open_api_models
+from alibabacloud_paistudio20210202 import models as pai_studio_20210202_models
+from alibabacloud_tea_util import models as util_models
+from alibabacloud_tea_util.client import Client as UtilClient
+from alibabacloud_eas20210701.client import Client as eas20210701Client
+from alibabacloud_paiflow20210202 import models as paiflow_20210202_models
+from alibabacloud_paiflow20210202.client import Client as PAIFlow20210202Client
+from datetime import datetime, timedelta
+from odps import ODPS
+from ad_monitor_util import _monitor
+
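+# This script drives a PAI Studio experiment draft end to end: it refreshes the
+# draft's training-date window, runs the sample-shuffle, training and model-export
+# nodes, evaluates the new model against the version currently serving on EAS,
+# and reports the result through ad_monitor_util._monitor.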
+# Display names of the workflow-draft nodes whose ids this script needs (see get_node_dict).
+target_names = {
+    '样本shuffle',
+    '模型训练-样本shufle',
+    '模型训练-自定义',
+    '模型增量训练',
+    '模型导出-2',
+    '更新EAS服务(Beta)-1',
+    '虚拟起始节点',
+    '二分类评估-1',
+    '二分类评估-2',
+    '预测结果对比'
+}
+
+# ID of the PAI Studio experiment draft that every call below targets.
+experiment_id = "draft-kbezr8f0q3cpee9eqc"
+
+
+def get_odps_instance(project):
+    odps = ODPS(
+        access_id='LTAIWYUujJAm7CbH',
+        secret_access_key='RfSjdiWwED1sGFlsjXv0DlfTnZTG1P',
+        project=project,
+        endpoint='http://service.cn.maxcompute.aliyun.com/api',
+    )
+    return odps
+
+
+def get_data_from_odps(project, table, num):
+    odps = get_odps_instance(project)
+    try:
+        # SQL statement to run
+        sql = f'select * from {table} limit {num}'
+        # Execute the query
+        with odps.execute_sql(sql).open_reader() as reader:
+            # Return None when fewer rows than requested are available
+            if reader.count < num:
+                return None
+            # Column names
+            column_names = reader.schema.names
+            # Collect the result rows
+            data = []
+            for record in reader:
+                record_list = list(record)
+                numbers = []
+                for item in record_list:
+                    numbers.append(item[1])
+                data.append(numbers)
+            # Combine rows and column names into a DataFrame
+            df = pd.DataFrame(data, columns=column_names)
+            return df
+    except Exception as e:
+        print(f"发生错误: {e}")
+
+
+def get_dict_from_odps(project, table):
+    odps = get_odps_instance(project)
+    try:
+        # SQL statement to run
+        sql = f'select * from {table}'
+        # Execute the query
+        with odps.execute_sql(sql).open_reader() as reader:
+            data = {}
+            for record in reader:
+                record_list = list(record)
+                key = record_list[0][1]
+                value = record_list[1][1]
+                data[key] = value
+            return data
+    except Exception as e:
+        print(f"发生错误: {e}")
+
+
+def get_dates_between(start_date_str, end_date_str):
+    start_date = datetime.strptime(start_date_str, '%Y%m%d')
+    end_date = datetime.strptime(end_date_str, '%Y%m%d')
+    dates = []
+    current_date = start_date
+    while current_date <= end_date:
+        dates.append(current_date.strftime('%Y%m%d'))
+        current_date += timedelta(days=1)
+    return dates
+
+
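+# For example, get_dates_between('20250227', '20250302') returns
+# ['20250227', '20250228', '20250301', '20250302'] (both endpoints inclusive).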
+def read_file_to_list():
+    try:
+        current_dir = os.getcwd()
+        file_path = os.path.join(current_dir, 'holidays.txt')
+        with open(file_path, 'r', encoding='utf-8') as file:
+            content = file.read()
+            return content.split('\n')
+    except FileNotFoundError:
+        print(f"错误:未找到 {file_path} 文件。")
+    except Exception as e:
+        print(f"错误:发生了一个未知错误: {e}")
+    return []
+
+
+def get_previous_days_date(days):
+    current_date = datetime.now()
+    previous_date = current_date - timedelta(days=days)
+    return previous_date.strftime('%Y%m%d')
+
+
+def remove_elements(lst1, lst2):
+    return [element for element in lst1 if element not in lst2]
+
+
+def process_list(lst, append_str):
+    # Prepend the same string to every element of the list
+    appended_list = [append_str + element for element in lst]
+    # Join the prefixed elements into one comma-separated string
+    result_str = ','.join(appended_list)
+    return result_str
+
+
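+# For example, process_list(['a', 'b'], 'x_') returns 'x_a,x_b'.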
+def get_train_data_list():
+    # Training dates: every day from start_date up to two days ago, minus the dates listed in holidays.txt
+    start_date = '20250223'
+    end_date = get_previous_days_date(2)
+    date_list = get_dates_between(start_date, end_date)
+    filter_date_list = read_file_to_list()
+    date_list = remove_elements(date_list, filter_date_list)
+    return date_list
+
+
+def update_train_tables(old_str):
+    date_list = get_train_data_list()
+    train_list = ["'" + item + "'" for item in date_list]
+    result = ','.join(train_list)
+    start_index = old_str.find('where dt in (')
+    if start_index != -1:
+        equal_sign_index = start_index + len('where dt in (')
+        # Find the closing parenthesis
+        next_quote_index = old_str.find(')', equal_sign_index)
+        if next_quote_index != -1:
+            # Splice in the rebuilt date list
+            new_value = old_str[:equal_sign_index] + result + old_str[next_quote_index:]
+            return new_value
+    return None
+
+
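+# Illustrative example (the SQL fragment is hypothetical): given a query containing
+# "where dt in ('20250223')", update_train_tables replaces the parenthesised list with
+# the full quoted training-date list, e.g. "where dt in ('20250223','20250224',...)".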
+def update_train_table(old_str, table):
+    address = 'odps://pai_algo/tables/'
+    train_table = address + table
+    start_index = old_str.find('-Dtrain_tables="')
+    if start_index != -1:
+        # Position right after the opening quote
+        equal_sign_index = start_index + len('-Dtrain_tables="')
+        # Find the next double quote
+        next_quote_index = old_str.find('"', equal_sign_index)
+        if next_quote_index != -1:
+            # Splice in the new table address
+            new_value = old_str[:equal_sign_index] + train_table + old_str[next_quote_index:]
+            return new_value
+    return None
+
+
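+# Illustrative example (table name is hypothetical): update_train_table(cmd, 'shuffle_out')
+# rewrites the -Dtrain_tables="..." value inside cmd to
+# -Dtrain_tables="odps://pai_algo/tables/shuffle_out".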
+class PAIClient:
+    """Thin wrappers around the PAI Studio, EAS and PAIFlow OpenAPI clients."""
+
+    def __init__(self):
+        pass
+
+    @staticmethod
+    def create_client() -> PaiStudio20210202Client:
+        """
+        Initialize a PAI Studio client with an AccessKey ID / AccessKey Secret (AK/SK).
+        @return: Client
+        @throws Exception
+        """
+        # Leaking project code can expose the AccessKey and put every resource under the
+        # account at risk. The snippet below is for reference only.
+        # STS credentials are the safer choice; see
+        # https://help.aliyun.com/document_detail/378659.html for other auth options.
+        config = open_api_models.Config(
+            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
+            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
+        )
+        # For endpoints, see https://api.aliyun.com/product/PaiStudio
+        config.endpoint = f'pai.cn-hangzhou.aliyuncs.com'
+        return PaiStudio20210202Client(config)
+
+    @staticmethod
+    def create_eas_client() -> eas20210701Client:
+        """
+        Initialize an EAS client with an AccessKey ID / AccessKey Secret (AK/SK).
+        @return: Client
+        @throws Exception
+        """
+        # Leaking project code can expose the AccessKey and put every resource under the
+        # account at risk. The snippet below is for reference only.
+        # STS credentials are the safer choice; see
+        # https://help.aliyun.com/document_detail/378659.html for other auth options.
+        config = open_api_models.Config(
+            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
+            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
+        )
+        # For endpoints, see https://api.aliyun.com/product/PaiStudio
+        config.endpoint = f'pai-eas.cn-hangzhou.aliyuncs.com'
+        return eas20210701Client(config)
+
+    @staticmethod
+    def create_flow_client() -> PAIFlow20210202Client:
+        """
+        Initialize a PAIFlow client with an AccessKey ID / AccessKey Secret (AK/SK).
+        @return: Client
+        @throws Exception
+        """
+        # Leaking project code can expose the AccessKey and put every resource under the
+        # account at risk. The snippet below is for reference only.
+        # STS credentials are the safer choice; see
+        # https://help.aliyun.com/document_detail/378659.html for other auth options.
+        config = open_api_models.Config(
+            # Required; normally read from the ALIBABA_CLOUD_ACCESS_KEY_ID environment variable.
+            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
+            # Required; normally read from the ALIBABA_CLOUD_ACCESS_KEY_SECRET environment variable.
+            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
+        )
+        # For endpoints, see https://api.aliyun.com/product/PAIFlow
+        config.endpoint = f'paiflow.cn-hangzhou.aliyuncs.com'
+        return PAIFlow20210202Client(config)
+
+    @staticmethod
+    def get_work_flow_draft_list(workspace_id: str):
+        client = PAIClient.create_client()
+        list_experiments_request = pai_studio_20210202_models.ListExperimentsRequest(
+            workspace_id=workspace_id
+        )
+        runtime = util_models.RuntimeOptions()
+        headers = {}
+        try:
+            resp = client.list_experiments_with_options(list_experiments_request, headers, runtime)
+            return resp.body.to_map()
+        except Exception as error:
+            print(error.message)
+            print(error.data.get("Recommend"))
+            UtilClient.assert_as_string(error.message)
+
+    @staticmethod
+    def get_work_flow_draft(experiment_id: str):
+        client = PAIClient.create_client()
+        runtime = util_models.RuntimeOptions()
+        headers = {}
+        try:
+            # If you copy and run this code, print the API response yourself.
+            resp = client.get_experiment_with_options(experiment_id, headers, runtime)
+            return resp.body.to_map()
+        except Exception as error:
+            # Printed for demonstration only; handle exceptions carefully and never just swallow them in production.
+            # Error message
+            print(error.message)
+            # Diagnostic URL
+            print(error.data.get("Recommend"))
+            UtilClient.assert_as_string(error.message)
+
+    @staticmethod
+    def get_describe_service(service_name: str):
+        client = PAIClient.create_eas_client()
+        runtime = util_models.RuntimeOptions()
+        headers = {}
+        try:
+            # If you copy and run this code, print the API response yourself.
+            resp = client.describe_service_with_options('cn-hangzhou', service_name, headers, runtime)
+            return resp.body.to_map()
+        except Exception as error:
+            # Printed for demonstration only; handle exceptions carefully and never just swallow them in production.
+            # Error message
+            print(error.message)
+            # Diagnostic URL
+            print(error.data.get("Recommend"))
+            UtilClient.assert_as_string(error.message)
+
+    @staticmethod
+    def update_experiment_content(experiment_id: str, content: str, version: int):
+        client = PAIClient.create_client()
+        update_experiment_content_request = pai_studio_20210202_models.UpdateExperimentContentRequest(
+            content=content, version=version)
+        runtime = util_models.RuntimeOptions()
+        headers = {}
+        try:
+            # If you copy and run this code, print the API response yourself.
+            resp = client.update_experiment_content_with_options(experiment_id, update_experiment_content_request,
+                                                                 headers, runtime)
+            print(resp.body.to_map())
+        except Exception as error:
+            # Printed for demonstration only; handle exceptions carefully and never just swallow them in production.
+            # Error message
+            print(error.message)
+            # Diagnostic URL
+            print(error.data.get("Recommend"))
+            UtilClient.assert_as_string(error.message)
+
+    @staticmethod
+    def create_job(experiment_id: str, node_id: str, execute_type: str):
+        client = PAIClient.create_client()
+        create_job_request = pai_studio_20210202_models.CreateJobRequest()
+        create_job_request.experiment_id = experiment_id
+        create_job_request.node_id = node_id
+        create_job_request.execute_type = execute_type
+        runtime = util_models.RuntimeOptions()
+        headers = {}
+        try:
+            # If you copy and run this code, print the API response yourself.
+            resp = client.create_job_with_options(create_job_request, headers, runtime)
+            return resp.body.to_map()
+        except Exception as error:
+            # Printed for demonstration only; handle exceptions carefully and never just swallow them in production.
+            # Error message
+            print(error.message)
+            # Diagnostic URL
+            print(error.data.get("Recommend"))
+            UtilClient.assert_as_string(error.message)
+
+    @staticmethod
+    def get_job_detail(job_id: str):
+        client = PAIClient.create_client()
+        get_job_request = pai_studio_20210202_models.GetJobRequest(
+            verbose=False
+        )
+        runtime = util_models.RuntimeOptions()
+        headers = {}
+        try:
+            # If you copy and run this code, print the API response yourself.
+            resp = client.get_job_with_options(job_id, get_job_request, headers, runtime)
+            return resp.body.to_map()
+        except Exception as error:
+            # Printed for demonstration only; handle exceptions carefully and never just swallow them in production.
+            # Error message
+            print(error.message)
+            # Diagnostic URL
+            print(error.data.get("Recommend"))
+            UtilClient.assert_as_string(error.message)
+
+    @staticmethod
+    def get_flow_out_put(pipeline_run_id: str, node_id: str, depth: int):
+        client = PAIClient.create_flow_client()
+        list_pipeline_run_node_outputs_request = paiflow_20210202_models.ListPipelineRunNodeOutputsRequest(
+            depth=depth
+        )
+        runtime = util_models.RuntimeOptions()
+        headers = {}
+        try:
+            # If you copy and run this code, print the API response yourself.
+            resp = client.list_pipeline_run_node_outputs_with_options(pipeline_run_id, node_id,
+                                                                      list_pipeline_run_node_outputs_request, headers,
+                                                                      runtime)
+            return resp.body.to_map()
+        except Exception as error:
+            # Printed for demonstration only; handle exceptions carefully and never just swallow them in production.
+            # Error message
+            print(error.message)
+            # Diagnostic URL
+            print(error.data.get("Recommend"))
+            UtilClient.assert_as_string(error.message)
+
+
+def extract_date_yyyymmdd(input_string):
+    pattern = r'\d{8}'
+    matches = re.findall(pattern, input_string)
+    if matches:
+        return matches[0]
+    return None
+
+
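+# For example, extract_date_yyyymmdd('ad_rank_dnn/20250410/model') returns '20250410'
+# (the path is only an illustration).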
+def get_online_version_dt(service_name: str):
+    # Read the serving EAS service config and pull the YYYYMMDD date out of its model_path.
+    model_detail = PAIClient.get_describe_service(service_name)
+    service_config_str = model_detail['ServiceConfig']
+    service_config = json.loads(service_config_str)
+    model_path = service_config['model_path']
+    online_date = extract_date_yyyymmdd(model_path)
+    return online_date
+
+
+def update_online_flow():
+    online_version_dt = get_online_version_dt('ad_rank_dnn_v11_easyrec')
+    draft = PAIClient.get_work_flow_draft(experiment_id)
+    print(json.dumps(draft, ensure_ascii=False))
+    content = draft['Content']
+    version = draft['Version']
+    print(content)
+    content_json = json.loads(content)
+    nodes = content_json.get('nodes')
+    global_params = content_json.get('globalParams')
+    bizdate = get_previous_days_date(1)
+    for global_param in global_params:
+        if global_param['name'] == 'bizdate':
+            global_param['value'] = bizdate
+        if global_param['name'] == 'online_version_dt':
+            global_param['value'] = online_version_dt
+        if global_param['name'] == 'eval_date':
+            global_param['value'] = bizdate
+    for node in nodes:
+        name = node['name']
+        if name == '样本shuffle':
+            properties = node['properties']
+            for property in properties:
+                if property['name'] == 'sql':
+                    value = property['value']
+                    new_value = update_train_tables(value)
+                    if new_value is None:
+                        print("error")
+                    else:
+                        property['value'] = new_value
+    new_content = json.dumps(content_json, ensure_ascii=False)
+    PAIClient.update_experiment_content(experiment_id, new_content, version)
+
+
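+# update_online_flow() refreshes the draft before each run: bizdate and eval_date are set
+# to yesterday, online_version_dt to the date parsed from the currently serving model,
+# and the '样本shuffle' node's SQL gets the rebuilt training-date list.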
+def update_shuffle_flow(table):
+    draft = PAIClient.get_work_flow_draft(experiment_id)
+    print(json.dumps(draft, ensure_ascii=False))
+    content = draft['Content']
+    version = draft['Version']
+    content_json = json.loads(content)
+    nodes = content_json.get('nodes')
+    for node in nodes:
+        name = node['name']
+        if name == '模型训练-样本shufle':
+            properties = node['properties']
+            for property in properties:
+                if property['name'] == 'sql':
+                    value = property['value']
+                    new_value = update_train_table(value, table)
+                    if new_value is None:
+                        print("error")
+                    else:
+                        property['value'] = new_value
+    new_content = json.dumps(content_json, ensure_ascii=False)
+    PAIClient.update_experiment_content(experiment_id, new_content, version)
+
+
+def update_shuffle_flow_1():
+    draft = PAIClient.get_work_flow_draft(experiment_id)
+    print(json.dumps(draft, ensure_ascii=False))
+    content = draft['Content']
+    version = draft['Version']
+    print(content)
+    content_json = json.loads(content)
+    nodes = content_json.get('nodes')
+    for node in nodes:
+        name = node['name']
+        if name == '模型训练-样本shufle':
+            properties = node['properties']
+            for property in properties:
+                if property['name'] == 'sql':
+                    value = property['value']
+                    new_value = update_train_tables(value)
+                    if new_value is None:
+                        print("error")
+                    else:
+                        property['value'] = new_value
+    new_content = json.dumps(content_json, ensure_ascii=False)
+    PAIClient.update_experiment_content(experiment_id, new_content, version)
+
+
+def wait_job_end(job_id: str):
+    while True:
+        job_detail = PAIClient.get_job_detail(job_id)
+        print(job_detail)
+        status = job_detail['Status']
+        # In-progress states: Initialized, Starting, WorkflowServiceStarting (about to submit),
+        # Running, ReadyToSchedule (waiting on upstream nodes)
+        if (status == 'Initialized' or status == 'Starting' or status == 'WorkflowServiceStarting'
+                or status == 'Running' or status == 'ReadyToSchedule'):
+            # Sleep 300s before polling again
+            time.sleep(300)
+            continue
+        # Terminal states: Failed, Terminating, Terminated, Unknown, Skipped (upstream failure), Succeeded
+        if (status == 'Failed' or status == 'Terminating' or status == 'Terminated'
+                or status == 'Unknown' or status == 'Skipped' or status == 'Succeeded'):
+            return job_detail
+
+
+def get_node_dict():
+    draft = PAIClient.get_work_flow_draft(experiment_id)
+    content = draft['Content']
+    content_json = json.loads(content)
+    nodes = content_json.get('nodes')
+    node_dict = {}
+    for node in nodes:
+        name = node['name']
+        # Keep only the nodes whose names are in the target set
+        if name in target_names:
+            node_dict[name] = node['id']
+    return node_dict
+
+
+def train_model():
+    node_dict = get_node_dict()
+    train_node_id = node_dict['样本shuffle']
+    execute_type = 'EXECUTE_ONE'
+    validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+    validate_job_id = validate_res['JobId']
+    validate_job_detail = wait_job_end(validate_job_id)
+    if validate_job_detail['Status'] == 'Succeeded':
+        pipeline_run_id = validate_job_detail['RunId']
+        node_id = validate_job_detail['PaiflowNodeId']
+        flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
+        out_puts = flow_out_put_detail['Outputs']
+        table = None
+        for out_put in out_puts:
+            if out_put["Producer"] == node_dict['样本shuffle'] and out_put["Name"] == "outputTable":
+                value1 = json.loads(out_put["Info"]['value'])
+                table = value1['location']['table']
+        if table is not None:
+            update_shuffle_flow(table)
+            node_dict = get_node_dict()
+            train_node_id = node_dict['模型训练-样本shufle']
+            execute_type = 'EXECUTE_ONE'
+            train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+            train_job_id = train_res['JobId']
+            train_job_detail = wait_job_end(train_job_id)
+            if train_job_detail['Status'] == 'Succeeded':
+                export_node_id = node_dict['模型导出-2']
+                export_res = PAIClient.create_job(experiment_id, export_node_id, execute_type)
+                export_job_id = export_res['JobId']
+                export_job_detail = wait_job_end(export_job_id)
+                if export_job_detail['Status'] == 'Succeeded':
+                    return True
+    return False
+
+
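+# train_model() chains three draft nodes: it runs '样本shuffle', points
+# '模型训练-样本shufle' at the shuffled output table via update_shuffle_flow, and then
+# runs '模型导出-2'; it returns True only when all three jobs end in Succeeded.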
+def update_online_model():
+    node_dict = get_node_dict()
+    train_node_id = node_dict['更新EAS服务(Beta)-1']
+    execute_type = 'EXECUTE_ONE'
+    train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+    train_job_id = train_res['JobId']
+    train_job_detail = wait_job_end(train_job_id)
+    if train_job_detail['Status'] == 'Succeeded':
+        return True
+    return False
+
+
+def validate_model_data_accuracy(start_time):
+    node_dict = get_node_dict()
+    train_node_id = node_dict['虚拟起始节点']
+    execute_type = 'EXECUTE_FROM_HERE'
+    validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+    validate_job_id = validate_res['JobId']
+    validate_job_detail = wait_job_end(validate_job_id)
+    if validate_job_detail['Status'] == 'Succeeded':
+        pipeline_run_id = validate_job_detail['RunId']
+        node_id = validate_job_detail['PaiflowNodeId']
+        flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
+        print(flow_out_put_detail)
+        table_dict = {}
+        out_puts = flow_out_put_detail['Outputs']
+        for out_put in out_puts:
+            if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
+                value1 = json.loads(out_put["Info"]['value'])
+                table_dict['二分类评估-1'] = value1['location']['table']
+            if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
+                value2 = json.loads(out_put["Info"]['value'])
+                table_dict['二分类评估-2'] = value2['location']['table']
+            if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
+                value3 = json.loads(out_put["Info"]['value'])
+                table_dict['预测结果对比'] = value3['location']['table']
+
+        num = 10
+        df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], num)
+        # Sum of absolute errors in the given columns, divided by num (Top-N mean absolute error)
+        old_abs_avg = df['old_error'].abs().sum() / num
+        new_abs_avg = df['new_error'].abs().sum() / num
+        new_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-1'])['AUC']
+        old_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-2'])['AUC']
+        bizdate = get_previous_days_date(1)
+        score_diff = abs(old_abs_avg - new_abs_avg)
+        msg = ""
+        level = ""
+        if new_abs_avg > 0.1:
+            msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
+            level = 'error'
+        elif score_diff > 0.05:
+            msg += f'两个模型评估{bizdate}的数据,两个模型分数差异为: {score_diff}, 大于0.05, 请检查'
+            level = 'error'
+        else:
+            # update_online_model()
+            msg += 'DNN广告模型更新完成'
+            level = 'info'
+        step_end_time = int(time.time())
+        elapsed = step_end_time - start_time
+
+        # Header of the markdown comparison table
+        top10_msg = "| CID | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
+        top10_msg += "\n| ---- | --------- | -------- |"
+
+        for index, row in df.iterrows():
+            # Read the relevant columns
+            cid = row['cid']
+            old_error = row['old_error']
+            new_error = row['new_error']
+            top10_msg += f"\n| {int(cid)} | {old_error} | {new_error} | "
+        print(top10_msg)
+        msg += f"\n\t - 老模型AUC: {old_auc}"
+        msg += f"\n\t - 新模型AUC: {new_auc}"
+        msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
+        msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
+        _monitor(level, msg, start_time, elapsed, top10_msg)
+
+
+if __name__ == '__main__':
+    start_time = int(time.time())
+    # 1. Update the workflow draft
+    update_online_flow()
+    # 2. Train the model
+    train_res = train_model()
+    if train_res:
+        # 3. Validate the model's output and push the model online
+        validate_model_data_accuracy(start_time)
+    else:
+        print('train_model_error')