|
@@ -1,11 +1,8 @@
|
|
|
# -*- coding: utf-8 -*-
|
|
|
+import functools
|
|
|
import os
|
|
|
import re
|
|
|
import sys
|
|
|
-
|
|
|
-from typing import List
|
|
|
-
|
|
|
-
|
|
|
import time
|
|
|
import json
|
|
|
import pandas as pd
|
|
@@ -34,7 +31,28 @@ target_names = {
|
|
|
'预测结果对比'
|
|
|
}
|
|
|
|
|
|
-experiment_id = "draft-kbezr8f0q3cpee9eqc"
|
|
|
+experiment_id = "draft-wqgkag89sbh9v1zvut"
|
|
|
+
|
|
|
+MAX_RETRIES = 3
|
|
|
+
|
|
|
+
|
|
|
+def retry(func):
|
|
|
+ @functools.wraps(func)
|
|
|
+ def wrapper(*args, **kwargs):
|
|
|
+ retries = 0
|
|
|
+ while retries < MAX_RETRIES:
|
|
|
+ try:
|
|
|
+ result = func(*args, **kwargs)
|
|
|
+ if result is not False:
|
|
|
+ return result
|
|
|
+ except Exception as e:
|
|
|
+ print(f"函数 {func.__name__} 执行时发生异常: {e},重试第 {retries + 1} 次")
|
|
|
+ retries += 1
|
|
|
+ print(f"函数 {func.__name__} 重试 {MAX_RETRIES} 次后仍失败。")
|
|
|
+ return False
|
|
|
+
|
|
|
+ return wrapper
|
|
|
+
|
|
|
|
|
|
def get_odps_instance(project):
|
|
|
odps = ODPS(
|
|
@@ -53,21 +71,10 @@ def get_data_from_odps(project, table, num):
|
|
|
sql = f'select * from {table} limit {num}'
|
|
|
# 执行 SQL 查询
|
|
|
with odps.execute_sql(sql).open_reader() as reader:
|
|
|
+ df = reader.to_pandas()
|
|
|
# 查询数量小于目标数量时 返回空
|
|
|
- if reader.count < num:
|
|
|
+ if len(df) < num:
|
|
|
return None
|
|
|
- # 获取字段名称
|
|
|
- column_names = reader.schema.names
|
|
|
- # 获取查询结果数据
|
|
|
- data = []
|
|
|
- for record in reader:
|
|
|
- record_list = list(record)
|
|
|
- numbers = []
|
|
|
- for item in record_list:
|
|
|
- numbers.append(item[1])
|
|
|
- data.append(numbers)
|
|
|
- # 将数据和字段名称组合成 DataFrame
|
|
|
- df = pd.DataFrame(data, columns=column_names)
|
|
|
return df
|
|
|
except Exception as e:
|
|
|
print(f"发生错误: {e}")
|
|
@@ -105,14 +112,14 @@ def get_dates_between(start_date_str, end_date_str):
|
|
|
def read_file_to_list():
|
|
|
try:
|
|
|
current_dir = os.getcwd()
|
|
|
- file_path = os.path.join(current_dir, 'holidays.txt')
|
|
|
+ file_path = os.path.join(current_dir, 'ad', 'holidays.txt')
|
|
|
with open(file_path, 'r', encoding='utf-8') as file:
|
|
|
content = file.read()
|
|
|
return content.split('\n')
|
|
|
except FileNotFoundError:
|
|
|
- print(f"错误:未找到 {file_path} 文件。")
|
|
|
+ raise Exception(f"错误:未找到 {file_path} 文件。")
|
|
|
except Exception as e:
|
|
|
- print(f"错误:发生了一个未知错误: {e}")
|
|
|
+ raise Exception(f"错误:发生了一个未知错误: {e}")
|
|
|
return []
|
|
|
|
|
|
|
|
@@ -135,7 +142,7 @@ def process_list(lst, append_str):
|
|
|
|
|
|
|
|
|
def get_train_data_list():
|
|
|
- start_date = '20250223'
|
|
|
+ start_date = '20250320'
|
|
|
end_date = get_previous_days_date(2)
|
|
|
date_list = get_dates_between(start_date, end_date)
|
|
|
filter_date_list = read_file_to_list()
|
|
@@ -145,32 +152,44 @@ def get_train_data_list():
|
|
|
|
|
|
def update_train_tables(old_str):
|
|
|
date_list = get_train_data_list()
|
|
|
- address = 'odps://loghubods/tables/ad_easyrec_train_data_v3_sampled/dt='
|
|
|
- train_tables = process_list(date_list, address)
|
|
|
- start_index = old_str.find('-Dtrain_tables="')
|
|
|
+ train_list = ["'" + item + "'" for item in date_list]
|
|
|
+ result = ','.join(train_list)
|
|
|
+ start_index = old_str.find('where dt in (')
|
|
|
if start_index != -1:
|
|
|
- # 确定等号的位置
|
|
|
- equal_sign_index = start_index + len('-Dtrain_tables="')
|
|
|
+ equal_sign_index = start_index + len('where dt in (')
|
|
|
# 找到下一个双引号的位置
|
|
|
- next_quote_index = old_str.find('"', equal_sign_index)
|
|
|
+ next_quote_index = old_str.find(')', equal_sign_index)
|
|
|
if next_quote_index != -1:
|
|
|
# 进行替换
|
|
|
- new_value = old_str[:equal_sign_index] + train_tables + old_str[next_quote_index:]
|
|
|
+ new_value = old_str[:equal_sign_index] + result + old_str[next_quote_index:]
|
|
|
return new_value
|
|
|
return None
|
|
|
|
|
|
-def new_update_train_tables(old_str):
|
|
|
- date_list = get_train_data_list()
|
|
|
- train_list = ["'" + item + "'" for item in date_list]
|
|
|
- result = ','.join(train_list)
|
|
|
- start_index = old_str.find('where dt in (')
|
|
|
+
|
|
|
+def compare_timestamp_with_today_start(time_str):
|
|
|
+ # 解析时间字符串为 datetime 对象
|
|
|
+ time_obj = datetime.fromisoformat(time_str)
|
|
|
+ # 将其转换为时间戳
|
|
|
+ target_timestamp = time_obj.timestamp()
|
|
|
+ # 获取今天开始的时间
|
|
|
+ today_start = datetime.combine(datetime.now().date(), datetime.min.time())
|
|
|
+ # 将今天开始时间转换为时间戳
|
|
|
+ today_start_timestamp = today_start.timestamp()
|
|
|
+ return target_timestamp > today_start_timestamp
|
|
|
+
|
|
|
+
|
|
|
+def update_train_table(old_str, table):
|
|
|
+ address = 'odps://pai_algo/tables/'
|
|
|
+ train_table = address + table
|
|
|
+ start_index = old_str.find('-Dtrain_tables="')
|
|
|
if start_index != -1:
|
|
|
- equal_sign_index = start_index + len('where dt in (')
|
|
|
+ # 确定等号的位置
|
|
|
+ equal_sign_index = start_index + len('-Dtrain_tables="')
|
|
|
# 找到下一个双引号的位置
|
|
|
- next_quote_index = old_str.find(')', equal_sign_index)
|
|
|
+ next_quote_index = old_str.find('"', equal_sign_index)
|
|
|
if next_quote_index != -1:
|
|
|
# 进行替换
|
|
|
- new_value = old_str[:equal_sign_index] + result + old_str[next_quote_index:]
|
|
|
+ new_value = old_str[:equal_sign_index] + train_table + old_str[next_quote_index:]
|
|
|
return new_value
|
|
|
return None
|
|
|
|
|
@@ -244,9 +263,7 @@ class PAIClient:
|
|
|
resp = client.list_experiments_with_options(list_experiments_request, headers, runtime)
|
|
|
return resp.body.to_map()
|
|
|
except Exception as error:
|
|
|
- print(error.message)
|
|
|
- print(error.data.get("Recommend"))
|
|
|
- UtilClient.assert_as_string(error.message)
|
|
|
+ raise Exception(f"get_work_flow_draft_list error {error}")
|
|
|
|
|
|
@staticmethod
|
|
|
def get_work_flow_draft(experiment_id: str):
|
|
@@ -258,12 +275,7 @@ class PAIClient:
|
|
|
resp = client.get_experiment_with_options(experiment_id, headers, runtime)
|
|
|
return resp.body.to_map()
|
|
|
except Exception as error:
|
|
|
- # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
|
|
|
- # 错误 message
|
|
|
- print(error.message)
|
|
|
- # 诊断地址
|
|
|
- print(error.data.get("Recommend"))
|
|
|
- UtilClient.assert_as_string(error.message)
|
|
|
+ raise Exception(f"get_work_flow_draft error {error}")
|
|
|
|
|
|
@staticmethod
|
|
|
def get_describe_service(service_name: str):
|
|
@@ -275,12 +287,7 @@ class PAIClient:
|
|
|
resp = client.describe_service_with_options('cn-hangzhou', service_name, headers, runtime)
|
|
|
return resp.body.to_map()
|
|
|
except Exception as error:
|
|
|
- # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
|
|
|
- # 错误 message
|
|
|
- print(error.message)
|
|
|
- # 诊断地址
|
|
|
- print(error.data.get("Recommend"))
|
|
|
- UtilClient.assert_as_string(error.message)
|
|
|
+ raise Exception(f"get_describe_service error {error}")
|
|
|
|
|
|
@staticmethod
|
|
|
def update_experiment_content(experiment_id: str, content: str, version: int):
|
|
@@ -295,12 +302,7 @@ class PAIClient:
|
|
|
headers, runtime)
|
|
|
print(resp.body.to_map())
|
|
|
except Exception as error:
|
|
|
- # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
|
|
|
- # 错误 message
|
|
|
- print(error.message)
|
|
|
- # 诊断地址
|
|
|
- print(error.data.get("Recommend"))
|
|
|
- UtilClient.assert_as_string(error.message)
|
|
|
+ raise Exception(f"update_experiment_content error {error}")
|
|
|
|
|
|
@staticmethod
|
|
|
def create_job(experiment_id: str, node_id: str, execute_type: str):
|
|
@@ -316,18 +318,29 @@ class PAIClient:
|
|
|
resp = client.create_job_with_options(create_job_request, headers, runtime)
|
|
|
return resp.body.to_map()
|
|
|
except Exception as error:
|
|
|
- # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
|
|
|
- # 错误 message
|
|
|
- print(error.message)
|
|
|
- # 诊断地址
|
|
|
- print(error.data.get("Recommend"))
|
|
|
- UtilClient.assert_as_string(error.message)
|
|
|
+ raise Exception(f"create_job error {error}")
|
|
|
|
|
|
@staticmethod
|
|
|
- def get_job_detail(job_id: str):
|
|
|
+ def get_jobs_list(experiment_id: str, order='DESC'):
|
|
|
+ client = PAIClient.create_client()
|
|
|
+ list_jobs_request = pai_studio_20210202_models.ListJobsRequest(
|
|
|
+ experiment_id=experiment_id,
|
|
|
+ order=order
|
|
|
+ )
|
|
|
+ runtime = util_models.RuntimeOptions()
|
|
|
+ headers = {}
|
|
|
+ try:
|
|
|
+            # 调用 ListJobs 接口,查询该实验下的任务列表
|
|
|
+ resp = client.list_jobs_with_options(list_jobs_request, headers, runtime)
|
|
|
+ return resp.body.to_map()
|
|
|
+ except Exception as error:
|
|
|
+ raise Exception(f"get_jobs_list error {error}")
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def get_job_detail(job_id: str, verbose=False):
|
|
|
client = PAIClient.create_client()
|
|
|
get_job_request = pai_studio_20210202_models.GetJobRequest(
|
|
|
- verbose=False
|
|
|
+ verbose=verbose
|
|
|
)
|
|
|
runtime = util_models.RuntimeOptions()
|
|
|
headers = {}
|
|
@@ -374,41 +387,34 @@ def extract_date_yyyymmdd(input_string):
|
|
|
return None
|
|
|
|
|
|
|
|
|
-def get_online_version_dt(service_name: str):
|
|
|
+def get_online_model_config(service_name: str):
|
|
|
+ model_config = {}
|
|
|
model_detail = PAIClient.get_describe_service(service_name)
|
|
|
service_config_str = model_detail['ServiceConfig']
|
|
|
service_config = json.loads(service_config_str)
|
|
|
model_path = service_config['model_path']
|
|
|
+ model_config['model_path'] = model_path
|
|
|
online_date = extract_date_yyyymmdd(model_path)
|
|
|
- return online_date
|
|
|
+ model_config['online_date'] = online_date
|
|
|
+ return model_config
|
|
|
|
|
|
|
|
|
-def update_online_flow():
|
|
|
- online_version_dt = get_online_version_dt('ad_rank_dnn_v11_easyrec')
|
|
|
+
|
|
|
+def update_shuffle_flow(table):
|
|
|
draft = PAIClient.get_work_flow_draft(experiment_id)
|
|
|
print(json.dumps(draft, ensure_ascii=False))
|
|
|
content = draft['Content']
|
|
|
version = draft['Version']
|
|
|
- print(content)
|
|
|
content_json = json.loads(content)
|
|
|
nodes = content_json.get('nodes')
|
|
|
- global_params = content_json.get('globalParams')
|
|
|
- bizdate = get_previous_days_date(1)
|
|
|
- for global_param in global_params:
|
|
|
- if global_param['name'] == 'bizdate':
|
|
|
- global_param['value'] = bizdate
|
|
|
- if global_param['name'] == 'online_version_dt':
|
|
|
- global_param['value'] = online_version_dt
|
|
|
- if global_param['name'] == 'eval_date':
|
|
|
- global_param['value'] = bizdate
|
|
|
for node in nodes:
|
|
|
name = node['name']
|
|
|
- if name == '模型训练-自定义':
|
|
|
+ if name == '模型训练-样本shufle':
|
|
|
properties = node['properties']
|
|
|
for property in properties:
|
|
|
if property['name'] == 'sql':
|
|
|
value = property['value']
|
|
|
- new_value = update_train_tables(value)
|
|
|
+ new_value = update_train_table(value, table)
|
|
|
if new_value is None:
|
|
|
print("error")
|
|
|
property['value'] = new_value
|
|
@@ -416,8 +422,7 @@ def update_online_flow():
|
|
|
PAIClient.update_experiment_content(experiment_id, new_content, version)
|
|
|
|
|
|
|
|
|
-def update_online_new_flow():
|
|
|
- online_version_dt = get_online_version_dt('ad_rank_dnn_v11_easyrec')
|
|
|
+def update_shuffle_flow_1():
|
|
|
draft = PAIClient.get_work_flow_draft(experiment_id)
|
|
|
print(json.dumps(draft, ensure_ascii=False))
|
|
|
content = draft['Content']
|
|
@@ -425,23 +430,14 @@ def update_online_new_flow():
|
|
|
print(content)
|
|
|
content_json = json.loads(content)
|
|
|
nodes = content_json.get('nodes')
|
|
|
- global_params = content_json.get('globalParams')
|
|
|
- bizdate = get_previous_days_date(1)
|
|
|
- for global_param in global_params:
|
|
|
- if global_param['name'] == 'bizdate':
|
|
|
- global_param['value'] = bizdate
|
|
|
- if global_param['name'] == 'online_version_dt':
|
|
|
- global_param['value'] = online_version_dt
|
|
|
- if global_param['name'] == 'eval_date':
|
|
|
- global_param['value'] = bizdate
|
|
|
for node in nodes:
|
|
|
name = node['name']
|
|
|
- if name == '样本shuffle':
|
|
|
+ if name == '模型训练-样本shufle':
|
|
|
properties = node['properties']
|
|
|
for property in properties:
|
|
|
if property['name'] == 'sql':
|
|
|
value = property['value']
|
|
|
- new_value = new_update_train_tables(value)
|
|
|
+ new_value = update_train_tables(value)
|
|
|
if new_value is None:
|
|
|
print("error")
|
|
|
property['value'] = new_value
|
|
@@ -466,7 +462,6 @@ def wait_job_end(job_id: str):
|
|
|
|
|
|
|
|
|
def get_node_dict():
|
|
|
- experiment_id = "draft-7u3e9v1uc5pohjl0t6"
|
|
|
draft = PAIClient.get_work_flow_draft(experiment_id)
|
|
|
content = draft['Content']
|
|
|
content_json = json.loads(content)
|
|
@@ -480,76 +475,210 @@ def get_node_dict():
|
|
|
return node_dict
|
|
|
|
|
|
|
|
|
-def train_model():
|
|
|
- node_dict = get_node_dict()
|
|
|
- train_node_id = node_dict['模型训练-自定义']
|
|
|
- execute_type = 'EXECUTE_ONE'
|
|
|
- train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
|
|
|
- train_job_id = train_res['JobId']
|
|
|
- train_job_detail = wait_job_end(train_job_id)
|
|
|
- if train_job_detail['Status'] == 'Succeeded':
|
|
|
+def get_job_dict():
|
|
|
+ job_dict = {}
|
|
|
+ jobs_list = PAIClient.get_jobs_list(experiment_id)
|
|
|
+ for job in jobs_list['Jobs']:
|
|
|
+        # 任务列表按创建时间降序返回,遇到今天零点之前创建的任务即可停止遍历
|
|
|
+ if not compare_timestamp_with_today_start(job['GmtCreateTime']):
|
|
|
+ break
|
|
|
+ job_id = job['JobId']
|
|
|
+ job_detail = PAIClient.get_job_detail(job_id, verbose=True)
|
|
|
+ for name in target_names:
|
|
|
+ if job_detail['Status'] != 'Succeeded':
|
|
|
+ continue
|
|
|
+ if name in job_dict:
|
|
|
+ continue
|
|
|
+ if name in job_detail['RunInfo']:
|
|
|
+ job_dict[name] = job_detail['JobId']
|
|
|
+ return job_dict
|
|
|
+
|
|
|
+@retry
|
|
|
+def update_online_flow():
|
|
|
+ try:
|
|
|
+ online_model_config = get_online_model_config('ad_rank_dnn_v11_easyrec')
|
|
|
+ draft = PAIClient.get_work_flow_draft(experiment_id)
|
|
|
+ print(json.dumps(draft, ensure_ascii=False))
|
|
|
+ content = draft['Content']
|
|
|
+ version = draft['Version']
|
|
|
+ print(content)
|
|
|
+ content_json = json.loads(content)
|
|
|
+ nodes = content_json.get('nodes')
|
|
|
+ global_params = content_json.get('globalParams')
|
|
|
+ bizdate = get_previous_days_date(1)
|
|
|
+ for global_param in global_params:
|
|
|
+ try:
|
|
|
+ if global_param['name'] == 'bizdate':
|
|
|
+ global_param['value'] = bizdate
|
|
|
+ if global_param['name'] == 'online_version_dt':
|
|
|
+ global_param['value'] = online_model_config['online_date']
|
|
|
+ if global_param['name'] == 'eval_date':
|
|
|
+ global_param['value'] = bizdate
|
|
|
+ if global_param['name'] == 'online_model_path':
|
|
|
+ global_param['value'] = online_model_config['model_path']
|
|
|
+ except KeyError:
|
|
|
+ raise Exception("在处理全局参数时,字典中缺少必要的键")
|
|
|
+ for node in nodes:
|
|
|
+ try:
|
|
|
+ name = node['name']
|
|
|
+ if name == '样本shuffle':
|
|
|
+ properties = node['properties']
|
|
|
+ for property in properties:
|
|
|
+ if property['name'] == 'sql':
|
|
|
+ value = property['value']
|
|
|
+ new_value = update_train_tables(value)
|
|
|
+ if new_value is None:
|
|
|
+ print("error")
|
|
|
+ property['value'] = new_value
|
|
|
+ except KeyError:
|
|
|
+ raise Exception("在处理节点属性时,字典中缺少必要的键")
|
|
|
+ new_content = json.dumps(content_json, ensure_ascii=False)
|
|
|
+ PAIClient.update_experiment_content(experiment_id, new_content, version)
|
|
|
+ return True
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ raise Exception("JSON 解析错误,可能是草稿内容格式不正确")
|
|
|
+ except Exception as e:
|
|
|
+ raise Exception(f"发生未知错误: {e}")
|
|
|
+
|
|
|
+@retry
|
|
|
+def shuffle_table():
|
|
|
+ try:
|
|
|
+ node_dict = get_node_dict()
|
|
|
+ train_node_id = node_dict['样本shuffle']
|
|
|
+ execute_type = 'EXECUTE_ONE'
|
|
|
+ validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
|
|
|
+ validate_job_id = validate_res['JobId']
|
|
|
+ validate_job_detail = wait_job_end(validate_job_id)
|
|
|
+ if validate_job_detail['Status'] == 'Succeeded':
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+ except Exception as e:
|
|
|
+ error_message = f"在执行 shuffle_table 函数时发生异常: {str(e)}"
|
|
|
+ print(error_message)
|
|
|
+ raise Exception(error_message)
|
|
|
+
|
|
|
+
|
|
|
+@retry
|
|
|
+def shuffle_train_model():
|
|
|
+ try:
|
|
|
+ node_dict = get_node_dict()
|
|
|
+ job_dict = get_job_dict()
|
|
|
+ job_id = job_dict['样本shuffle']
|
|
|
+ validate_job_detail = wait_job_end(job_id)
|
|
|
+ if validate_job_detail['Status'] == 'Succeeded':
|
|
|
+ pipeline_run_id = validate_job_detail['RunId']
|
|
|
+ node_id = validate_job_detail['PaiflowNodeId']
|
|
|
+ flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
|
|
|
+ out_puts = flow_out_put_detail['Outputs']
|
|
|
+ table = None
|
|
|
+ for out_put in out_puts:
|
|
|
+ if out_put["Producer"] == node_dict['样本shuffle'] and out_put["Name"] == "outputTable":
|
|
|
+ value1 = json.loads(out_put["Info"]['value'])
|
|
|
+ table = value1['location']['table']
|
|
|
+ if table is not None:
|
|
|
+ update_shuffle_flow(table)
|
|
|
+ node_dict = get_node_dict()
|
|
|
+ train_node_id = node_dict['模型训练-样本shufle']
|
|
|
+ execute_type = 'EXECUTE_ONE'
|
|
|
+ train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
|
|
|
+ train_job_id = train_res['JobId']
|
|
|
+ train_job_detail = wait_job_end(train_job_id)
|
|
|
+ if train_job_detail['Status'] == 'Succeeded':
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+ except Exception as e:
|
|
|
+ error_message = f"在执行 shuffle_train_model 函数时发生异常: {str(e)}"
|
|
|
+ print(error_message)
|
|
|
+ raise Exception(error_message)
|
|
|
+
|
|
|
+
|
|
|
+@retry
|
|
|
+def export_model():
|
|
|
+ try:
|
|
|
+ node_dict = get_node_dict()
|
|
|
export_node_id = node_dict['模型导出-2']
|
|
|
+ execute_type = 'EXECUTE_ONE'
|
|
|
export_res = PAIClient.create_job(experiment_id, export_node_id, execute_type)
|
|
|
export_job_id = export_res['JobId']
|
|
|
export_job_detail = wait_job_end(export_job_id)
|
|
|
if export_job_detail['Status'] == 'Succeeded':
|
|
|
return True
|
|
|
- return False
|
|
|
+ return False
|
|
|
+ except Exception as e:
|
|
|
+ error_message = f"在执行 export_model 函数时发生异常: {str(e)}"
|
|
|
+ print(error_message)
|
|
|
+ raise Exception(error_message)
|
|
|
|
|
|
|
|
|
def update_online_model():
|
|
|
- node_dict = get_node_dict()
|
|
|
- experiment_id = "draft-7u3e9v1uc5pohjl0t6"
|
|
|
- train_node_id = node_dict['更新EAS服务(Beta)-1']
|
|
|
- execute_type = 'EXECUTE_ONE'
|
|
|
- train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
|
|
|
- train_job_id = train_res['JobId']
|
|
|
- train_job_detail = wait_job_end(train_job_id)
|
|
|
- if train_job_detail['Status'] == 'Succeeded':
|
|
|
- return True
|
|
|
- return False
|
|
|
+ try:
|
|
|
+ node_dict = get_node_dict()
|
|
|
+ train_node_id = node_dict['更新EAS服务(Beta)-1']
|
|
|
+ execute_type = 'EXECUTE_ONE'
|
|
|
+ train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
|
|
|
+ train_job_id = train_res['JobId']
|
|
|
+ train_job_detail = wait_job_end(train_job_id)
|
|
|
+ if train_job_detail['Status'] == 'Succeeded':
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+ except Exception as e:
|
|
|
+ error_message = f"在执行 update_online_model 函数时发生异常: {str(e)}"
|
|
|
+ print(error_message)
|
|
|
+ raise Exception(error_message)
|
|
|
|
|
|
|
|
|
-def validate_model_data_accuracy():
|
|
|
- node_dict = get_node_dict()
|
|
|
- train_node_id = node_dict['虚拟起始节点']
|
|
|
- execute_type = 'EXECUTE_FROM_HERE'
|
|
|
- validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
|
|
|
- print(validate_res)
|
|
|
- validate_job_id = validate_res['JobId']
|
|
|
- print(validate_job_id)
|
|
|
- validate_job_detail = wait_job_end(validate_job_id)
|
|
|
- print('res')
|
|
|
- print(validate_job_detail)
|
|
|
- if validate_job_detail['Status'] == 'Succeeded':
|
|
|
- pipeline_run_id = validate_job_detail['RunId']
|
|
|
- node_id = validate_job_detail['PaiflowNodeId']
|
|
|
- flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
|
|
|
- print(flow_out_put_detail)
|
|
|
- tabel_dict = {}
|
|
|
- out_puts = flow_out_put_detail['Outputs']
|
|
|
- for out_put in out_puts:
|
|
|
- if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
|
|
|
- value1 = json.loads(out_put["Info"]['value'])
|
|
|
- tabel_dict['二分类评估-1'] = value1['location']['table']
|
|
|
- if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
|
|
|
- value2 = json.loads(out_put["Info"]['value'])
|
|
|
- tabel_dict['二分类评估-2'] = value2['location']['table']
|
|
|
- if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
|
|
|
- value3 = json.loads(out_put["Info"]['value'])
|
|
|
- tabel_dict['预测结果对比'] = value3['location']['table']
|
|
|
+@retry
|
|
|
+def get_validate_model_data():
|
|
|
+ try:
|
|
|
+ node_dict = get_node_dict()
|
|
|
+ train_node_id = node_dict['虚拟起始节点']
|
|
|
+ execute_type = 'EXECUTE_FROM_HERE'
|
|
|
+ validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
|
|
|
+ validate_job_id = validate_res['JobId']
|
|
|
+ validate_job_detail = wait_job_end(validate_job_id)
|
|
|
+ if validate_job_detail['Status'] == 'Succeeded':
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+ except Exception as e:
|
|
|
+ error_message = f"在执行 get_validate_model_data 函数时出现异常: {e}"
|
|
|
+ print(error_message)
|
|
|
+ raise Exception(error_message)
|
|
|
+
|
|
|
|
|
|
+def validate_model_data_accuracy():
|
|
|
+ try:
|
|
|
+ table_dict = {}
|
|
|
+ node_dict = get_node_dict()
|
|
|
+ job_dict = get_job_dict()
|
|
|
+ job_id = job_dict['虚拟起始节点']
|
|
|
+ validate_job_detail = wait_job_end(job_id)
|
|
|
+ if validate_job_detail['Status'] == 'Succeeded':
|
|
|
+ pipeline_run_id = validate_job_detail['RunId']
|
|
|
+ node_id = validate_job_detail['PaiflowNodeId']
|
|
|
+ flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
|
|
|
+ print(flow_out_put_detail)
|
|
|
+ out_puts = flow_out_put_detail['Outputs']
|
|
|
+ for out_put in out_puts:
|
|
|
+ if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
|
|
|
+ value1 = json.loads(out_put["Info"]['value'])
|
|
|
+ table_dict['二分类评估-1'] = value1['location']['table']
|
|
|
+ if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
|
|
|
+ value2 = json.loads(out_put["Info"]['value'])
|
|
|
+ table_dict['二分类评估-2'] = value2['location']['table']
|
|
|
+ if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
|
|
|
+ value3 = json.loads(out_put["Info"]['value'])
|
|
|
+ table_dict['预测结果对比'] = value3['location']['table']
|
|
|
num = 10
|
|
|
- df = get_data_from_odps('pai_algo', tabel_dict['预测结果对比'], 10)
|
|
|
+ df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], 10)
|
|
|
# 对指定列取绝对值再求和
|
|
|
old_abs_avg = df['old_error'].abs().sum() / num
|
|
|
new_abs_avg = df['new_error'].abs().sum() / num
|
|
|
- new_auc = get_dict_from_odps('pai_algo', tabel_dict['二分类评估-1'])['AUC']
|
|
|
- old_auc = get_dict_from_odps('pai_algo', tabel_dict['二分类评估-2'])['AUC']
|
|
|
+ new_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-1'])['AUC']
|
|
|
+ old_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-2'])['AUC']
|
|
|
bizdate = get_previous_days_date(1)
|
|
|
score_diff = abs(old_abs_avg - new_abs_avg)
|
|
|
msg = ""
|
|
|
- level = ""
|
|
|
+ result = False
|
|
|
if new_abs_avg > 0.1:
|
|
|
msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
|
|
|
level = 'error'
|
|
@@ -557,11 +686,9 @@ def validate_model_data_accuracy():
|
|
|
msg += f'两个模型评估${bizdate}的数据,两个模型分数差异为: ${score_diff}, 大于0.05, 请检查'
|
|
|
level = 'error'
|
|
|
else:
|
|
|
- # update_online_model()
|
|
|
msg += 'DNN广告模型更新完成'
|
|
|
level = 'info'
|
|
|
- step_end_time = int(time.time())
|
|
|
- elapsed = step_end_time - start_time
|
|
|
+ result = True
|
|
|
|
|
|
# 初始化表格头部
|
|
|
top10_msg = "| CID | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
|
|
@@ -578,78 +705,44 @@ def validate_model_data_accuracy():
|
|
|
msg += f"\n\t - 新模型AUC: {new_auc}"
|
|
|
msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
|
|
|
msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
|
|
|
- _monitor(level, msg, start_time, elapsed, top10_msg)
|
|
|
+ return result, msg, level, top10_msg
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ error_message = f"在执行 validate_model_data_accuracy 函数时出现异常: {str(e)}"
|
|
|
+ print(error_message)
|
|
|
+ raise Exception(error_message)
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
start_time = int(time.time())
|
|
|
- # 1.更新工作流
|
|
|
- update_online_new_flow()
|
|
|
-
|
|
|
- # 2.训练模型
|
|
|
- train_model()
|
|
|
-
|
|
|
- # 3. 验证模型数据 & 更新模型到线上
|
|
|
- # validate_model_data_accuracy()
|
|
|
- # start_time = int(time.time())
|
|
|
- # node_dict = get_node_dict()
|
|
|
- # str = '{"Creator": "204034041838504386", "ExecuteType": "EXECUTE_FROM_HERE", "ExperimentId": "draft-7u3e9v1uc5pohjl0t6", "GmtCreateTime": "2025-04-01T03:17:42.000+00:00", "JobId": "job-8u3ev2uf5ncoexj9p9", "PaiflowNodeId": "node-9wtveoz1tu89tqfoox", "RequestId": "6ED5FFB1-346B-5075-ACC9-029EB77E9F09", "RunId": "flow-lchat027733ttstdc0", "Status": "Succeeded", "WorkspaceId": "96094"}'
|
|
|
- # validate_job_detail = json.loads(str)
|
|
|
- # if validate_job_detail['Status'] == 'Succeeded':
|
|
|
- # pipeline_run_id = validate_job_detail['RunId']
|
|
|
- # node_id = validate_job_detail['PaiflowNodeId']
|
|
|
- # flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
|
|
|
- # print(flow_out_put_detail)
|
|
|
- # tabel_dict = {}
|
|
|
- # out_puts = flow_out_put_detail['Outputs']
|
|
|
- # for out_put in out_puts:
|
|
|
- # if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
|
|
|
- # value1 = json.loads(out_put["Info"]['value'])
|
|
|
- # tabel_dict['二分类评估-1'] = value1['location']['table']
|
|
|
- # if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
|
|
|
- # value2 = json.loads(out_put["Info"]['value'])
|
|
|
- # tabel_dict['二分类评估-2'] = value2['location']['table']
|
|
|
- # if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
|
|
|
- # value3 = json.loads(out_put["Info"]['value'])
|
|
|
- # tabel_dict['预测结果对比'] = value3['location']['table']
|
|
|
- #
|
|
|
- # num = 10
|
|
|
- # df = get_data_from_odps('pai_algo', tabel_dict['预测结果对比'], 10)
|
|
|
- # # 对指定列取绝对值再求和
|
|
|
- # old_abs_avg = df['old_error'].abs().sum() / num
|
|
|
- # new_abs_avg = df['new_error'].abs().sum() / num
|
|
|
- # new_auc = get_dict_from_odps('pai_algo', tabel_dict['二分类评估-1'])['AUC']
|
|
|
- # old_auc = get_dict_from_odps('pai_algo', tabel_dict['二分类评估-2'])['AUC']
|
|
|
- # bizdate = get_previous_days_date(1)
|
|
|
- # score_diff = abs(old_abs_avg - new_abs_avg)
|
|
|
- # msg = ""
|
|
|
- # level = ""
|
|
|
- # if new_abs_avg > 0.1:
|
|
|
- # msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
|
|
|
- # level = 'error'
|
|
|
- # elif score_diff > 0.05:
|
|
|
- # msg += f'两个模型评估${bizdate}的数据,两个模型分数差异为: ${score_diff}, 大于0.05, 请检查'
|
|
|
- # level = 'error'
|
|
|
- # else:
|
|
|
- # # TODO 更新模型到线上
|
|
|
- # msg += 'DNN广告模型更新完成'
|
|
|
- # level = 'info'
|
|
|
- # step_end_time = int(time.time())
|
|
|
- # elapsed = step_end_time - start_time
|
|
|
- #
|
|
|
- # # 初始化表格头部
|
|
|
- # top10_msg = "| CID | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
|
|
|
- # top10_msg += "\n| ---- | --------- | -------- |"
|
|
|
- #
|
|
|
- # for index, row in df.iterrows():
|
|
|
- # # 获取指定列的元素
|
|
|
- # cid = row['cid']
|
|
|
- # old_error = row['old_error']
|
|
|
- # new_error = row['new_error']
|
|
|
- # top10_msg += f"\n| {int(cid)} | {old_error} | {new_error} | "
|
|
|
- # print(top10_msg)
|
|
|
- # msg += f"\n\t - 老模型AUC: {old_auc}"
|
|
|
- # msg += f"\n\t - 新模型AUC: {new_auc}"
|
|
|
- # msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
|
|
|
- # msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
|
|
|
- # _monitor('info', msg, start_time, elapsed, top10_msg)
|
|
|
+ functions = [update_online_flow, shuffle_table, shuffle_train_model, export_model, get_validate_model_data]
|
|
|
+ function_names = [func.__name__ for func in functions]
|
|
|
+
|
|
|
+ start_function = None
|
|
|
+ if len(sys.argv) > 1:
|
|
|
+ start_function = sys.argv[1]
|
|
|
+ if start_function not in function_names:
|
|
|
+ print(f"指定的起始函数 {start_function} 不存在,请选择以下函数之一:{', '.join(function_names)}")
|
|
|
+ sys.exit(1)
|
|
|
+
|
|
|
+ start_index = 0
|
|
|
+ if start_function:
|
|
|
+ start_index = function_names.index(start_function)
|
|
|
+
|
|
|
+ for func in functions[start_index:]:
|
|
|
+ if not func():
|
|
|
+ print(f"{func.__name__} 执行失败,后续函数不再执行。")
|
|
|
+ step_end_time = int(time.time())
|
|
|
+ elapsed = step_end_time - start_time
|
|
|
+ _monitor('error', f"DNN模型更新,{func.__name__} 执行失败,后续函数不再执行,请检查", start_time, elapsed, None)
|
|
|
+ break
|
|
|
+ else:
|
|
|
+ print("所有函数都成功执行,可以继续下一步操作。")
|
|
|
+ result, msg, level, top10_msg = validate_model_data_accuracy()
|
|
|
+ if result:
|
|
|
+ # update_online_model()
|
|
|
+ print("success")
|
|
|
+ step_end_time = int(time.time())
|
|
|
+ elapsed = step_end_time - start_time
|
|
|
+ print(level, msg, start_time, elapsed, top10_msg)
|
|
|
+ _monitor(level, msg, start_time, elapsed, top10_msg)
|