
Update ad/pai_flow_operator: new pipeline

fengzhoutian committed 5 days ago (commit 98e9c93b7d)
1 changed file with 68 additions and 23 deletions
ad/pai_flow_operator.py  +68 -23

@@ -20,6 +20,7 @@ from ad_monitor_util import _monitor
 
 target_names = {
     '样本shuffle',
+    '评估样本重组',
     '模型训练-样本shufle',
     '模型训练-自定义',
     '模型增量训练',
@@ -143,14 +144,14 @@ def process_list(lst, append_str):
 
 def get_train_data_list():
     start_date = '20250320'
-    end_date = get_previous_days_date(2)
+    end_date = get_previous_days_date(1)
     date_list = get_dates_between(start_date, end_date)
     filter_date_list = read_file_to_list()
     date_list = remove_elements(date_list, filter_date_list)
     return date_list
 
 
-def update_train_tables(old_str):
+def update_data_date_range(old_str):
     date_list = get_train_data_list()
     train_list = ["'" + item + "'" for item in date_list]
     result = ','.join(train_list)
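
The replacement logic inside update_data_date_range is not shown in this hunk. As a rough sketch, assuming the node SQL carries a dt IN (...) filter, the quoted date list built above could be spliced in as below; splice_date_range is a hypothetical helper for illustration, not part of the repo:

    import re

    # Illustrative only: the real update_data_date_range may use a different
    # placeholder or replacement strategy.
    def splice_date_range(sql, date_list):
        quoted = ','.join("'" + d + "'" for d in date_list)   # '20250320','20250321',...
        # Replace whatever currently sits inside a "dt in (...)" clause with the new list.
        return re.sub(r"dt\s+in\s*\([^)]*\)", "dt in (" + quoted + ")", sql, flags=re.IGNORECASE)

    print(splice_date_range("select * from sample where dt in ('20250101')",
                            ['20250320', '20250321']))
    # select * from sample where dt in ('20250320','20250321')
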
@@ -399,7 +400,6 @@ def get_online_model_config(service_name: str):
     return model_config
 
 
-
 def update_shuffle_flow(table):
     draft = PAIClient.get_work_flow_draft(experiment_id)
     print(json.dumps(draft, ensure_ascii=False))
@@ -437,7 +437,7 @@ def update_shuffle_flow_1():
             for property in properties:
                 if property['name'] == 'sql':
                     value = property['value']
-                    new_value = update_train_tables(value)
+                    new_value = update_data_date_range(value)
                     if new_value is None:
                         print("error")
                     property['value'] = new_value
@@ -445,7 +445,7 @@ def update_shuffle_flow_1():
     PAIClient.update_experiment_content(experiment_id, new_content, version)
 
 
-def wait_job_end(job_id: str):
+def wait_job_end(job_id: str, check_interval=300):
     while True:
         job_detail = PAIClient.get_job_detail(job_id)
         print(job_detail)
@@ -453,8 +453,7 @@ def wait_job_end(job_id: str):
         # Initialized: initialization complete  Starting: starting  WorkflowServiceStarting: preparing to submit  Running: running  ReadyToSchedule: ready to schedule (waiting on upstream nodes)
         if (statue == 'Initialized' or statue == 'Starting' or statue == 'WorkflowServiceStarting'
                 or statue == 'Running' or statue == 'ReadyToSchedule'):
-            # 睡眠300s 等待下次获取
-            time.sleep(300)
+            time.sleep(check_interval)
             continue
         # Failed: run failed  Terminating: terminating  Terminated: terminated  Unknown: unknown  Skipped: skipped (an upstream node failed)  Succeeded: run succeeded
         if statue == 'Failed' or statue == 'Terminating' or statue == 'Unknown' or statue == 'Skipped' or statue == 'Succeeded':
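
The loop below is a self-contained sketch of the polling pattern wait_job_end follows after this change: non-terminal statuses sleep for check_interval seconds, terminal ones return the job detail. poll_until_done and fetch_status are stand-ins for illustration, not functions from the repo; in the operator the status would come from PAIClient.get_job_detail(job_id):

    import time

    RUNNING_STATES = {'Initialized', 'Starting', 'WorkflowServiceStarting', 'Running', 'ReadyToSchedule'}
    TERMINAL_STATES = {'Failed', 'Terminating', 'Terminated', 'Unknown', 'Skipped', 'Succeeded'}

    def poll_until_done(fetch_status, check_interval=300):
        # fetch_status is a placeholder callable returning {'Status': ...}.
        while True:
            detail = fetch_status()
            if detail['Status'] in RUNNING_STATES:
                time.sleep(check_interval)   # short intervals (e.g. 10 s) suit quick validation jobs
                continue
            if detail['Status'] in TERMINAL_STATES:
                return detail

    # toy run: finishes on the second poll
    states = iter([{'Status': 'Running'}, {'Status': 'Succeeded'}])
    print(poll_until_done(lambda: next(states), check_interval=0))
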
@@ -521,12 +520,12 @@ def update_online_flow():
         for node in nodes:
             try:
                 name = node['name']
-                if name == '样本shuffle':
+                if name in ('样本shuffle', '评估样本重组'):
                     properties = node['properties']
                     for property in properties:
                         if property['name'] == 'sql':
                             value = property['value']
-                            new_value = update_train_tables(value)
+                            new_value = update_data_date_range(value)
                             if new_value is None:
                                 print("error")
                             property['value'] = new_value
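
For illustration, the name-filter-and-rewrite pattern used in update_online_flow (and update_shuffle_flow_1) can be exercised on toy data; rewrite_sql below is a stand-in for update_data_date_range and the node list is fabricated:

    SQL_NODES = ('样本shuffle', '评估样本重组')

    def rewrite_sql(sql):
        return sql + " /* refreshed */"   # placeholder transform

    nodes = [
        {'name': '样本shuffle', 'properties': [{'name': 'sql', 'value': 'select 1'}]},
        {'name': '模型训练-自定义', 'properties': [{'name': 'sql', 'value': 'select 2'}]},
    ]
    for node in nodes:
        if node['name'] in SQL_NODES:
            for prop in node['properties']:
                if prop['name'] == 'sql':
                    prop['value'] = rewrite_sql(prop['value'])

    print(nodes[0]['properties'][0]['value'])   # select 1 /* refreshed */
    print(nodes[1]['properties'][0]['value'])   # select 2 (node name not in SQL_NODES, unchanged)
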
@@ -540,15 +539,36 @@ def update_online_flow():
     except Exception as e:
         raise Exception(f"发生未知错误: {e}")
 
+def update_global_param(params):
+    try:
+        draft = PAIClient.get_work_flow_draft(experiment_id)
+        content = draft['Content']
+        version = draft['Version']
+        content_json = json.loads(content)
+        nodes = content_json.get('nodes')
+        global_params = content_json.get('globalParams')
+        for global_param in global_params:
+            if global_param['name'] in params:
+                value = params[global_param['name']]
+                print(f"update global param {global_param['name']}: {value}")
+                global_param['value'] = value
+        new_content = json.dumps(content_json, ensure_ascii=False)
+        PAIClient.update_experiment_content(experiment_id, new_content, version)
+        return True
+    except json.JSONDecodeError:
+        raise Exception("JSON 解析错误,可能是草稿内容格式不正确")
+    except Exception as e:
+        raise Exception(f"发生未知错误: {e}")
+
 @retry
 def shuffle_table():
     try:
         node_dict = get_node_dict()
         train_node_id = node_dict['样本shuffle']
-        execute_type = 'EXECUTE_ONE'
+        execute_type = 'EXECUTE_FROM_HERE'
         validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
         validate_job_id = validate_res['JobId']
-        validate_job_detail = wait_job_end(validate_job_id)
+        validate_job_detail = wait_job_end(validate_job_id, 10)
         if validate_job_detail['Status'] == 'Succeeded':
             return True
         return False
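
Alongside shuffle_table switching to EXECUTE_FROM_HERE with a 10 s polling interval, the new update_global_param only rewrites globalParams entries whose names appear in the supplied dict. A minimal, self-contained sketch of that transformation, with a fabricated draft payload and no PAIClient calls:

    import json

    # Fabricated draft content for illustration; the real draft comes from
    # PAIClient.get_work_flow_draft(experiment_id).
    draft_content = json.dumps({
        'nodes': [],
        'globalParams': [
            {'name': 'eval_table_name', 'value': 'old_table'},
            {'name': 'learning_rate', 'value': '0.01'},
        ],
    })

    def apply_global_params(content, params):
        content_json = json.loads(content)
        for global_param in content_json.get('globalParams', []):
            if global_param['name'] in params:
                global_param['value'] = params[global_param['name']]
        return json.dumps(content_json, ensure_ascii=False)

    print(apply_global_params(draft_content, {'eval_table_name': 'pai_temp_eval_sample'}))
    # only eval_table_name changes; learning_rate is left untouched
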
@@ -569,11 +589,11 @@ def shuffle_train_model():
             pipeline_run_id = validate_job_detail['RunId']
             node_id = validate_job_detail['PaiflowNodeId']
             flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
-            out_puts = flow_out_put_detail['Outputs']
+            outputs = flow_out_put_detail['Outputs']
             table = None
-            for out_put in out_puts:
-                if out_put["Producer"] == node_dict['样本shuffle'] and out_put["Name"] == "outputTable":
-                    value1 = json.loads(out_put["Info"]['value'])
+            for output in outputs:
+                if output["Producer"] == node_dict['样本shuffle'] and output["Name"] == "outputTable":
+                    value1 = json.loads(output["Info"]['value'])
                     table = value1['location']['table']
             if table is not None:
                 update_shuffle_flow(table)
@@ -626,9 +646,34 @@ def update_online_model():
         print(error_message)
         raise Exception(error_message)
 
+def update_validation_config():
+    try:
+        job_dict = get_job_dict()
+        node_dict = get_node_dict()
+        print(node_dict)
+        job_id = job_dict['样本shuffle']
+        validate_job_detail = wait_job_end(job_id)
+        table = None
+        if validate_job_detail['Status'] == 'Succeeded':
+            pipeline_run_id = validate_job_detail['RunId']
+            node_id = validate_job_detail['PaiflowNodeId']
+            flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
+            outputs = flow_out_put_detail['Outputs']
+            for output in outputs:
+                if output["Producer"] == node_dict['评估样本重组'] and output["Name"] == "outputTable":
+                    value1 = json.loads(output["Info"]['value'])
+                    table = value1['location']['table']
+        if not table:
+            raise Exception("table not available")
+        update_global_param({'eval_table_name': table})
+    except Exception as e:
+        error_message = f"在执行 update_validation_config 函数时发生异常: {str(e)}"
+        print(error_message)
+        raise Exception(error_message)
 
 @retry
 def get_validate_model_data():
+    update_validation_config()
     try:
         node_dict = get_node_dict()
         train_node_id = node_dict['虚拟起始节点']
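
update_validation_config resolves the evaluation table by scanning the shuffle job's Outputs for the entry produced by the 评估样本重组 node and named outputTable, then writes it into the eval_table_name global parameter before get_validate_model_data runs. A self-contained sketch of that lookup over a fabricated payload (node ids and table name are invented):

    import json

    # Shaped like PAIClient.get_flow_out_put(...)['Outputs'], values fabricated.
    outputs = [
        {'Producer': 'node-eval-resample', 'Name': 'outputTable',
         'Info': {'value': json.dumps({'location': {'table': 'pai_temp_eval_sample'}})}},
        {'Producer': 'node-other', 'Name': 'outputModel', 'Info': {'value': '{}'}},
    ]

    def find_output_table(outputs, producer_node_id):
        for output in outputs:
            if output['Producer'] == producer_node_id and output['Name'] == 'outputTable':
                return json.loads(output['Info']['value'])['location']['table']
        return None

    print(find_output_table(outputs, 'node-eval-resample'))   # pai_temp_eval_sample
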
@@ -657,16 +702,16 @@ def validate_model_data_accuracy():
             node_id = validate_job_detail['PaiflowNodeId']
             flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
             print(flow_out_put_detail)
-            out_puts = flow_out_put_detail['Outputs']
-            for out_put in out_puts:
-                if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
-                    value1 = json.loads(out_put["Info"]['value'])
+            outputs = flow_out_put_detail['Outputs']
+            for output in outputs:
+                if output["Producer"] == node_dict['二分类评估-1'] and output["Name"] == "outputMetricTable":
+                    value1 = json.loads(output["Info"]['value'])
                     table_dict['二分类评估-1'] = value1['location']['table']
-                if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
-                    value2 = json.loads(out_put["Info"]['value'])
+                if output["Producer"] == node_dict['二分类评估-2'] and output["Name"] == "outputMetricTable":
+                    value2 = json.loads(output["Info"]['value'])
                     table_dict['二分类评估-2'] = value2['location']['table']
-                if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
-                    value3 = json.loads(out_put["Info"]['value'])
+                if output["Producer"] == node_dict['预测结果对比'] and output["Name"] == "outputTable":
+                    value3 = json.loads(output["Info"]['value'])
                     table_dict['预测结果对比'] = value3['location']['table']
         num = 10
         df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], 10)