@@ -20,6 +20,7 @@ from ad_monitor_util import _monitor
 
 target_names = {
     '样本shuffle',
+    '评估样本重组',
     '模型训练-样本shufle',
     '模型训练-自定义',
     '模型增量训练',
@@ -143,14 +144,14 @@ def process_list(lst, append_str):
 
 def get_train_data_list():
     start_date = '20250320'
-    end_date = get_previous_days_date(2)
+    end_date = get_previous_days_date(1)
     date_list = get_dates_between(start_date, end_date)
     filter_date_list = read_file_to_list()
     date_list = remove_elements(date_list, filter_date_list)
     return date_list
 
 
-def update_train_tables(old_str):
+def update_data_date_range(old_str):
     date_list = get_train_data_list()
     train_list = ["'" + item + "'" for item in date_list]
     result = ','.join(train_list)
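
Not part of the patch: the hunk above widens the training window by one day, since get_previous_days_date(1) presumably resolves to yesterday rather than two days ago. Those date helpers are not shown in this diff; a minimal sketch of the behavior the change assumes (YYYYMMDD strings, inclusive range) would be:

    from datetime import datetime, timedelta

    def get_previous_days_date(days: int) -> str:
        # The date `days` days before today, as a YYYYMMDD string.
        return (datetime.now() - timedelta(days=days)).strftime('%Y%m%d')

    def get_dates_between(start_date: str, end_date: str) -> list:
        # Inclusive list of YYYYMMDD strings from start_date through end_date.
        start = datetime.strptime(start_date, '%Y%m%d')
        end = datetime.strptime(end_date, '%Y%m%d')
        return [(start + timedelta(days=i)).strftime('%Y%m%d')
                for i in range((end - start).days + 1)]
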
@@ -399,7 +400,6 @@ def get_online_model_config(service_name: str):
     return model_config
 
 
-
 def update_shuffle_flow(table):
     draft = PAIClient.get_work_flow_draft(experiment_id)
     print(json.dumps(draft, ensure_ascii=False))
@@ -437,7 +437,7 @@ def update_shuffle_flow_1():
             for property in properties:
                 if property['name'] == 'sql':
                     value = property['value']
-                    new_value = update_train_tables(value)
+                    new_value = update_data_date_range(value)
                     if new_value is None:
                         print("error")
                     property['value'] = new_value
@@ -445,7 +445,7 @@ def update_shuffle_flow_1():
     PAIClient.update_experiment_content(experiment_id, new_content, version)
 
 
-def wait_job_end(job_id: str):
+def wait_job_end(job_id: str, check_interval=300):
     while True:
         job_detail = PAIClient.get_job_detail(job_id)
         print(job_detail)
@@ -453,8 +453,7 @@ def wait_job_end(job_id: str):
         # Initialized: 初始化完成 Starting:开始 WorkflowServiceStarting:准备提交 Running:运行中 ReadyToSchedule:准备运行(前序节点未完成导致)
         if (statue == 'Initialized' or statue == 'Starting' or statue == 'WorkflowServiceStarting'
                 or statue == 'Running' or statue == 'ReadyToSchedule'):
-            # 睡眠300s 等待下次获取
-            time.sleep(300)
+            time.sleep(check_interval)
             continue
         # Failed:运行失败 Terminating:终止中 Terminated:已终止 Unknown:未知 Skipped:跳过(前序节点失败导致) Succeeded:运行成功
         if statue == 'Failed' or statue == 'Terminating' or statue == 'Unknown' or statue == 'Skipped' or statue == 'Succeeded':
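
Not part of the patch: the two hunks above make the polling interval of wait_job_end configurable while keeping the 300-second default, so existing callers are unchanged. Callers that want faster feedback can pass a shorter interval, as shuffle_table does further down in this diff:

    # Poll job status every 10 seconds instead of the default 300.
    validate_job_detail = wait_job_end(validate_job_id, 10)
    # Equivalent, with the keyword spelled out:
    validate_job_detail = wait_job_end(validate_job_id, check_interval=10)
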
@@ -521,12 +520,12 @@ def update_online_flow():
         for node in nodes:
             try:
                 name = node['name']
-                if name == '样本shuffle':
+                if name in ('样本shuffle', '评估样本重组'):
                     properties = node['properties']
                     for property in properties:
                         if property['name'] == 'sql':
                             value = property['value']
-                            new_value = update_train_tables(value)
+                            new_value = update_data_date_range(value)
                             if new_value is None:
                                 print("error")
                             property['value'] = new_value
@@ -540,15 +539,36 @@ def update_online_flow():
     except Exception as e:
         raise Exception(f"发生未知错误: {e}")
 
+def update_global_param(params):
+    try:
+        draft = PAIClient.get_work_flow_draft(experiment_id)
+        content = draft['Content']
+        version = draft['Version']
+        content_json = json.loads(content)
+        nodes = content_json.get('nodes')
+        global_params = content_json.get('globalParams')
+        for global_param in global_params:
+            if global_param['name'] in params:
+                value = params[global_param['name']]
+                print(f"update global param {global_param['name']}: {value}")
+                global_param['value'] = value
+        new_content = json.dumps(content_json, ensure_ascii=False)
+        PAIClient.update_experiment_content(experiment_id, new_content, version)
+        return True
+    except json.JSONDecodeError:
+        raise Exception("JSON 解析错误,可能是草稿内容格式不正确")
+    except Exception as e:
+        raise Exception(f"发生未知错误: {e}")
+
 @retry
 def shuffle_table():
     try:
         node_dict = get_node_dict()
         train_node_id = node_dict['样本shuffle']
-        execute_type = 'EXECUTE_ONE'
+        execute_type = 'EXECUTE_FROM_HERE'
         validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
         validate_job_id = validate_res['JobId']
-        validate_job_detail = wait_job_end(validate_job_id)
+        validate_job_detail = wait_job_end(validate_job_id, 10)
         if validate_job_detail['Status'] == 'Succeeded':
             return True
         return False
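
Not part of the patch: the new update_global_param helper above rewrites matching entries in the draft's globalParams and saves the draft back with the fetched Version. A usage sketch; eval_table_name is the parameter this patch actually targets later, while the table value is invented for illustration:

    # Point the experiment's eval_table_name global parameter at a new table.
    update_global_param({'eval_table_name': 'pai_algo.eval_sample_20250501'})
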
@@ -569,11 +589,11 @@ def shuffle_train_model():
         pipeline_run_id = validate_job_detail['RunId']
         node_id = validate_job_detail['PaiflowNodeId']
         flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
-        out_puts = flow_out_put_detail['Outputs']
+        outputs = flow_out_put_detail['Outputs']
         table = None
-        for out_put in out_puts:
-            if out_put["Producer"] == node_dict['样本shuffle'] and out_put["Name"] == "outputTable":
-                value1 = json.loads(out_put["Info"]['value'])
+        for output in outputs:
+            if output["Producer"] == node_dict['样本shuffle'] and output["Name"] == "outputTable":
+                value1 = json.loads(output["Info"]['value'])
                 table = value1['location']['table']
         if table is not None:
             update_shuffle_flow(table)
@@ -626,9 +646,34 @@ def update_online_model():
         print(error_message)
         raise Exception(error_message)
 
+def update_validation_config():
+    try:
+        job_dict = get_job_dict()
+        node_dict = get_node_dict()
+        print(node_dict)
+        job_id = job_dict['样本shuffle']
+        validate_job_detail = wait_job_end(job_id)
+        table = None
+        if validate_job_detail['Status'] == 'Succeeded':
+            pipeline_run_id = validate_job_detail['RunId']
+            node_id = validate_job_detail['PaiflowNodeId']
+            flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
+            outputs = flow_out_put_detail['Outputs']
+            for output in outputs:
+                if output["Producer"] == node_dict['评估样本重组'] and output["Name"] == "outputTable":
+                    value1 = json.loads(output["Info"]['value'])
+                    table = value1['location']['table']
+        if not table:
+            raise Exception("table not available")
+        update_global_param({'eval_table_name': table})
+    except Exception as e:
+        error_message = f"在执行 update_validation_config 函数时发生异常: {str(e)}"
+        print(error_message)
+        raise Exception(error_message)
 
 @retry
 def get_validate_model_data():
+    update_validation_config()
     try:
         node_dict = get_node_dict()
         train_node_id = node_dict['虚拟起始节点']
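
Not part of the patch: update_validation_config above locates the 评估样本重组 output by matching Producer and Name, then JSON-decodes Info['value'] to obtain the table. A small sketch of that parsing, with a made-up Outputs entry whose shape is inferred from the code:

    import json

    # One flow output entry as this code expects it (field values invented).
    output = {
        "Producer": "node-123",
        "Name": "outputTable",
        "Info": {"value": '{"location": {"table": "pai_algo.eval_sample_20250501"}}'},
    }
    table = json.loads(output["Info"]["value"])["location"]["table"]
    print(table)  # pai_algo.eval_sample_20250501
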
@@ -657,16 +702,16 @@ def validate_model_data_accuracy():
         node_id = validate_job_detail['PaiflowNodeId']
         flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
         print(flow_out_put_detail)
-        out_puts = flow_out_put_detail['Outputs']
-        for out_put in out_puts:
-            if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
-                value1 = json.loads(out_put["Info"]['value'])
+        outputs = flow_out_put_detail['Outputs']
+        for output in outputs:
+            if output["Producer"] == node_dict['二分类评估-1'] and output["Name"] == "outputMetricTable":
+                value1 = json.loads(output["Info"]['value'])
                 table_dict['二分类评估-1'] = value1['location']['table']
-            if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
-                value2 = json.loads(out_put["Info"]['value'])
+            if output["Producer"] == node_dict['二分类评估-2'] and output["Name"] == "outputMetricTable":
+                value2 = json.loads(output["Info"]['value'])
                 table_dict['二分类评估-2'] = value2['location']['table']
-            if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
-                value3 = json.loads(out_put["Info"]['value'])
+            if output["Producer"] == node_dict['预测结果对比'] and output["Name"] == "outputTable":
+                value3 = json.loads(output["Info"]['value'])
                 table_dict['预测结果对比'] = value3['location']['table']
         num = 10
         df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], 10)