
Move pai_flow_operator2 to pai_flow_operator

fengzhoutian committed 5 days ago · commit 216a14d879
3 changed files, 315 insertions(+) and 970 deletions(-)

  1. ad/02_ad_model_dnn_v11_update.sh  (+1 -1)
  2. ad/pai_flow_operator.py  (+314 -221)
  3. ad/pai_flow_operator2.py  (+0 -748)

ad/02_ad_model_dnn_v11_update.sh  (+1 -1)

@@ -218,7 +218,7 @@ bucket_feature_from_origin_to_hive() {
 run_pai_flow() {
   local step_start_time=$(date +%s)
 
-  python ad/pai_flow_operator2.py
+  python ad/pai_flow_operator.py
   
   local return_code=$?
   check_run_status ${return_code} ${step_start_time} "PAI工作流任务" "PAI工作流执行失败"

ad/pai_flow_operator.py  (+314 -221)

@@ -1,11 +1,8 @@
 # -*- coding: utf-8 -*-
+import functools
 import os
 import re
 import sys
-
-from typing import List
-
-
 import time
 import json
 import pandas as pd
@@ -34,7 +31,28 @@ target_names = {
     '预测结果对比'
 }
 
-experiment_id = "draft-kbezr8f0q3cpee9eqc"
+experiment_id = "draft-wqgkag89sbh9v1zvut"
+
+MAX_RETRIES = 3
+
+
+def retry(func):
+    @functools.wraps(func)
+    def wrapper(*args, **kwargs):
+        retries = 0
+        while retries < MAX_RETRIES:
+            try:
+                result = func(*args, **kwargs)
+                if result is not False:
+                    return result
+            except Exception as e:
+                print(f"函数 {func.__name__} 执行时发生异常: {e},重试第 {retries + 1} 次")
+            retries += 1
+        print(f"函数 {func.__name__} 重试 {MAX_RETRIES} 次后仍失败。")
+        return False
+
+    return wrapper
+
 
 def get_odps_instance(project):
     odps = ODPS(
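
For context (not part of the diff): a minimal sketch of how the @retry decorator added in the hunk above is meant to wrap the pipeline steps. flaky_step and do_pai_call are hypothetical placeholders; a step returns True on success, and returning False (or raising) makes the wrapper run it again, up to MAX_RETRIES times, before it gives up and returns False.

    @retry
    def flaky_step():
        ok = do_pai_call()  # hypothetical PAI API call
        return ok           # True -> returned immediately; False/exception -> retried

    if not flaky_step():
        print("step failed after retries")
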
@@ -53,21 +71,10 @@ def get_data_from_odps(project, table, num):
         sql = f'select * from {table} limit {num}'
         # 执行 SQL 查询
         with odps.execute_sql(sql).open_reader() as reader:
+            df = reader.to_pandas()
             # 查询数量小于目标数量时 返回空
-            if reader.count < num:
+            if len(df) < num:
                 return None
-            # 获取字段名称
-            column_names = reader.schema.names
-            # 获取查询结果数据
-            data = []
-            for record in reader:
-                record_list = list(record)
-                numbers = []
-                for item in record_list:
-                    numbers.append(item[1])
-                data.append(numbers)
-            # 将数据和字段名称组合成 DataFrame
-            df = pd.DataFrame(data, columns=column_names)
             return df
     except Exception as e:
         print(f"发生错误: {e}")
@@ -105,14 +112,14 @@ def get_dates_between(start_date_str, end_date_str):
 def read_file_to_list():
     try:
         current_dir = os.getcwd()
-        file_path = os.path.join(current_dir, 'holidays.txt')
+        file_path = os.path.join(current_dir, 'ad', 'holidays.txt')
         with open(file_path, 'r', encoding='utf-8') as file:
             content = file.read()
             return content.split('\n')
     except FileNotFoundError:
-        print(f"错误:未找到 {file_path} 文件。")
+        raise Exception(f"错误:未找到 {file_path} 文件。")
     except Exception as e:
-        print(f"错误:发生了一个未知错误: {e}")
+        raise Exception(f"错误:发生了一个未知错误: {e}")
     return []
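
An assumption worth recording (the diff only moves the path to ad/holidays.txt, read relative to the working directory): given how these entries are later filtered out of the yyyymmdd training range, the file is expected to hold one yyyymmdd date per line. A hypothetical sketch:

    # ad/holidays.txt (illustrative contents):
    #   20250404
    #   20250405
    holiday_dates = read_file_to_list()
    print(holiday_dates[:2])  # e.g. ['20250404', '20250405']
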
 
 
@@ -135,7 +142,7 @@ def process_list(lst, append_str):
 
 
 def get_train_data_list():
-    start_date = '20250223'
+    start_date = '20250320'
     end_date = get_previous_days_date(2)
     date_list = get_dates_between(start_date, end_date)
     filter_date_list = read_file_to_list()
@@ -145,32 +152,44 @@ def get_train_data_list():
 
 def update_train_tables(old_str):
     date_list = get_train_data_list()
-    address = 'odps://loghubods/tables/ad_easyrec_train_data_v3_sampled/dt='
-    train_tables = process_list(date_list, address)
-    start_index = old_str.find('-Dtrain_tables="')
+    train_list = ["'" + item + "'" for item in date_list]
+    result = ','.join(train_list)
+    start_index = old_str.find('where dt in (')
     if start_index != -1:
-        # 确定等号的位置
-        equal_sign_index = start_index + len('-Dtrain_tables="')
+        equal_sign_index = start_index + len('where dt in (')
         # 找到下一个双引号的位置
-        next_quote_index = old_str.find('"', equal_sign_index)
+        next_quote_index = old_str.find(')', equal_sign_index)
         if next_quote_index != -1:
             # 进行替换
-            new_value = old_str[:equal_sign_index] + train_tables + old_str[next_quote_index:]
+            new_value = old_str[:equal_sign_index] + result + old_str[next_quote_index:]
             return new_value
     return None
 
-def new_update_train_tables(old_str):
-    date_list = get_train_data_list()
-    train_list = ["'" + item + "'" for item in date_list]
-    result = ','.join(train_list)
-    start_index = old_str.find('where dt in (')
+
+def compare_timestamp_with_today_start(time_str):
+    # 解析时间字符串为 datetime 对象
+    time_obj = datetime.fromisoformat(time_str)
+    # 将其转换为时间戳
+    target_timestamp = time_obj.timestamp()
+    # 获取今天开始的时间
+    today_start = datetime.combine(datetime.now().date(), datetime.min.time())
+    # 将今天开始时间转换为时间戳
+    today_start_timestamp = today_start.timestamp()
+    return target_timestamp > today_start_timestamp
+
+
+def update_train_table(old_str, table):
+    address = 'odps://pai_algo/tables/'
+    train_table = address + table
+    start_index = old_str.find('-Dtrain_tables="')
     if start_index != -1:
-        equal_sign_index = start_index + len('where dt in (')
+        # 确定等号的位置
+        equal_sign_index = start_index + len('-Dtrain_tables="')
         # 找到下一个双引号的位置
-        next_quote_index = old_str.find(')', equal_sign_index)
+        next_quote_index = old_str.find('"', equal_sign_index)
         if next_quote_index != -1:
             # 进行替换
-            new_value = old_str[:equal_sign_index] + result + old_str[next_quote_index:]
+            new_value = old_str[:equal_sign_index] + train_table + old_str[next_quote_index:]
             return new_value
     return None
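
A hedged illustration of the two rewrite helpers in this hunk (the SQL statement and the training command are made-up stand-ins; running the first call also needs ad/holidays.txt to exist): update_train_tables swaps the partition list inside a "where dt in (...)" clause for the current training dates, while update_train_table points -Dtrain_tables="..." at the freshly shuffled table.

    sql = "select * from sample_table where dt in ('20250101','20250102')"
    print(update_train_tables(sql))
    # -> select * from sample_table where dt in ('20250320',...,'<two days ago>')

    cmd = 'pai -name easy_rec_ext -Dtrain_tables="odps://old_project/tables/old_table" -Dcmd=train'
    print(update_train_table(cmd, 'shuffled_train_table_20250401'))
    # -> ... -Dtrain_tables="odps://pai_algo/tables/shuffled_train_table_20250401" ...
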
 
@@ -244,9 +263,7 @@ class PAIClient:
             resp = client.list_experiments_with_options(list_experiments_request, headers, runtime)
             return resp.body.to_map()
         except Exception as error:
-            print(error.message)
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
+            raise Exception(f"get_work_flow_draft_list error {error}")
 
     @staticmethod
     def get_work_flow_draft(experiment_id: str):
@@ -258,12 +275,7 @@ class PAIClient:
             resp = client.get_experiment_with_options(experiment_id, headers, runtime)
             return resp.body.to_map()
         except Exception as error:
-            # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
-            # 错误 message
-            print(error.message)
-            # 诊断地址
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
+            raise Exception(f"get_work_flow_draft error {error}")
 
     @staticmethod
     def get_describe_service(service_name: str):
@@ -275,12 +287,7 @@ class PAIClient:
             resp = client.describe_service_with_options('cn-hangzhou', service_name, headers, runtime)
             return resp.body.to_map()
         except Exception as error:
-            # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
-            # 错误 message
-            print(error.message)
-            # 诊断地址
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
+            raise Exception(f"get_describe_service error {error}")
 
     @staticmethod
     def update_experiment_content(experiment_id: str, content: str, version: int):
@@ -295,12 +302,7 @@ class PAIClient:
                                                                  headers, runtime)
             print(resp.body.to_map())
         except Exception as error:
-            # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
-            # 错误 message
-            print(error.message)
-            # 诊断地址
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
+            raise Exception(f"update_experiment_content error {error}")
 
     @staticmethod
     def create_job(experiment_id: str, node_id: str, execute_type: str):
@@ -316,18 +318,29 @@ class PAIClient:
             resp = client.create_job_with_options(create_job_request, headers, runtime)
             return resp.body.to_map()
         except Exception as error:
-            # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
-            # 错误 message
-            print(error.message)
-            # 诊断地址
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
+            raise Exception(f"create_job error {error}")
 
     @staticmethod
-    def get_job_detail(job_id: str):
+    def get_jobs_list(experiment_id: str, order='DESC'):
+        client = PAIClient.create_client()
+        list_jobs_request = pai_studio_20210202_models.ListJobsRequest(
+            experiment_id=experiment_id,
+            order=order
+        )
+        runtime = util_models.RuntimeOptions()
+        headers = {}
+        try:
+            # 复制代码运行请自行打印 API 的返回值
+            resp = client.list_jobs_with_options(list_jobs_request, headers, runtime)
+            return resp.body.to_map()
+        except Exception as error:
+            raise Exception(f"get_jobs_list error {error}")
+
+    @staticmethod
+    def get_job_detail(job_id: str, verbose=False):
         client = PAIClient.create_client()
         get_job_request = pai_studio_20210202_models.GetJobRequest(
-            verbose=False
+            verbose=verbose
         )
         runtime = util_models.RuntimeOptions()
         headers = {}
@@ -374,41 +387,34 @@ def extract_date_yyyymmdd(input_string):
     return None
 
 
-def get_online_version_dt(service_name: str):
+def get_online_model_config(service_name: str):
+    model_config = {}
     model_detail = PAIClient.get_describe_service(service_name)
     service_config_str = model_detail['ServiceConfig']
     service_config = json.loads(service_config_str)
     model_path = service_config['model_path']
+    model_config['model_path'] = model_path
     online_date = extract_date_yyyymmdd(model_path)
-    return online_date
+    model_config['online_date'] = online_date
+    return model_config
 
 
-def update_online_flow():
-    online_version_dt = get_online_version_dt('ad_rank_dnn_v11_easyrec')
+
+def update_shuffle_flow(table):
     draft = PAIClient.get_work_flow_draft(experiment_id)
     print(json.dumps(draft, ensure_ascii=False))
     content = draft['Content']
     version = draft['Version']
-    print(content)
     content_json = json.loads(content)
     nodes = content_json.get('nodes')
-    global_params = content_json.get('globalParams')
-    bizdate = get_previous_days_date(1)
-    for global_param in global_params:
-        if global_param['name'] == 'bizdate':
-            global_param['value'] = bizdate
-        if global_param['name'] == 'online_version_dt':
-            global_param['value'] = online_version_dt
-        if global_param['name'] == 'eval_date':
-            global_param['value'] = bizdate
     for node in nodes:
         name = node['name']
-        if name == '模型训练-自定义':
+        if name == '模型训练-样本shufle':
             properties = node['properties']
             for property in properties:
                 if property['name'] == 'sql':
                     value = property['value']
-                    new_value = update_train_tables(value)
+                    new_value = update_train_table(value, table)
                     if new_value is None:
                         print("error")
                     property['value'] = new_value
@@ -416,8 +422,7 @@ def update_online_flow():
     PAIClient.update_experiment_content(experiment_id, new_content, version)
 
 
-def update_online_new_flow():
-    online_version_dt = get_online_version_dt('ad_rank_dnn_v11_easyrec')
+def update_shuffle_flow_1():
     draft = PAIClient.get_work_flow_draft(experiment_id)
     print(json.dumps(draft, ensure_ascii=False))
     content = draft['Content']
@@ -425,23 +430,14 @@ def update_online_new_flow():
     print(content)
     content_json = json.loads(content)
     nodes = content_json.get('nodes')
-    global_params = content_json.get('globalParams')
-    bizdate = get_previous_days_date(1)
-    for global_param in global_params:
-        if global_param['name'] == 'bizdate':
-            global_param['value'] = bizdate
-        if global_param['name'] == 'online_version_dt':
-            global_param['value'] = online_version_dt
-        if global_param['name'] == 'eval_date':
-            global_param['value'] = bizdate
     for node in nodes:
         name = node['name']
-        if name == '样本shuffle':
+        if name == '模型训练-样本shufle':
             properties = node['properties']
             for property in properties:
                 if property['name'] == 'sql':
                     value = property['value']
-                    new_value = new_update_train_tables(value)
+                    new_value = update_train_tables(value)
                     if new_value is None:
                         print("error")
                     property['value'] = new_value
@@ -466,7 +462,6 @@ def wait_job_end(job_id: str):
 
 
 def get_node_dict():
-    experiment_id = "draft-7u3e9v1uc5pohjl0t6"
     draft = PAIClient.get_work_flow_draft(experiment_id)
     content = draft['Content']
     content_json = json.loads(content)
@@ -480,76 +475,210 @@ def get_node_dict():
     return node_dict
 
 
-def train_model():
-    node_dict = get_node_dict()
-    train_node_id = node_dict['模型训练-自定义']
-    execute_type = 'EXECUTE_ONE'
-    train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
-    train_job_id = train_res['JobId']
-    train_job_detail = wait_job_end(train_job_id)
-    if train_job_detail['Status'] == 'Succeeded':
+def get_job_dict():
+    job_dict = {}
+    jobs_list = PAIClient.get_jobs_list(experiment_id)
+    for job in jobs_list['Jobs']:
+        # 解析时间字符串为 datetime 对象
+        if not compare_timestamp_with_today_start(job['GmtCreateTime']):
+            break
+        job_id = job['JobId']
+        job_detail = PAIClient.get_job_detail(job_id, verbose=True)
+        for name in target_names:
+            if job_detail['Status'] != 'Succeeded':
+                continue
+            if name in job_dict:
+                continue
+            if name in job_detail['RunInfo']:
+                job_dict[name] = job_detail['JobId']
+    return job_dict
+
+@retry
+def update_online_flow():
+    try:
+        online_model_config = get_online_model_config('ad_rank_dnn_v11_easyrec')
+        draft = PAIClient.get_work_flow_draft(experiment_id)
+        print(json.dumps(draft, ensure_ascii=False))
+        content = draft['Content']
+        version = draft['Version']
+        print(content)
+        content_json = json.loads(content)
+        nodes = content_json.get('nodes')
+        global_params = content_json.get('globalParams')
+        bizdate = get_previous_days_date(1)
+        for global_param in global_params:
+            try:
+                if global_param['name'] == 'bizdate':
+                    global_param['value'] = bizdate
+                if global_param['name'] == 'online_version_dt':
+                    global_param['value'] = online_model_config['online_date']
+                if global_param['name'] == 'eval_date':
+                    global_param['value'] = bizdate
+                if global_param['name'] == 'online_model_path':
+                    global_param['value'] = online_model_config['model_path']
+            except KeyError:
+                raise Exception("在处理全局参数时,字典中缺少必要的键")
+        for node in nodes:
+            try:
+                name = node['name']
+                if name == '样本shuffle':
+                    properties = node['properties']
+                    for property in properties:
+                        if property['name'] == 'sql':
+                            value = property['value']
+                            new_value = update_train_tables(value)
+                            if new_value is None:
+                                print("error")
+                            property['value'] = new_value
+            except KeyError:
+                raise Exception("在处理节点属性时,字典中缺少必要的键")
+        new_content = json.dumps(content_json, ensure_ascii=False)
+        PAIClient.update_experiment_content(experiment_id, new_content, version)
+        return True
+    except json.JSONDecodeError:
+        raise Exception("JSON 解析错误,可能是草稿内容格式不正确")
+    except Exception as e:
+        raise Exception(f"发生未知错误: {e}")
+
+@retry
+def shuffle_table():
+    try:
+        node_dict = get_node_dict()
+        train_node_id = node_dict['样本shuffle']
+        execute_type = 'EXECUTE_ONE'
+        validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+        validate_job_id = validate_res['JobId']
+        validate_job_detail = wait_job_end(validate_job_id)
+        if validate_job_detail['Status'] == 'Succeeded':
+            return True
+        return False
+    except Exception as e:
+        error_message = f"在执行 shuffle_table 函数时发生异常: {str(e)}"
+        print(error_message)
+        raise Exception(error_message)
+
+
+@retry
+def shuffle_train_model():
+    try:
+        node_dict = get_node_dict()
+        job_dict = get_job_dict()
+        job_id = job_dict['样本shuffle']
+        validate_job_detail = wait_job_end(job_id)
+        if validate_job_detail['Status'] == 'Succeeded':
+            pipeline_run_id = validate_job_detail['RunId']
+            node_id = validate_job_detail['PaiflowNodeId']
+            flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
+            out_puts = flow_out_put_detail['Outputs']
+            table = None
+            for out_put in out_puts:
+                if out_put["Producer"] == node_dict['样本shuffle'] and out_put["Name"] == "outputTable":
+                    value1 = json.loads(out_put["Info"]['value'])
+                    table = value1['location']['table']
+            if table is not None:
+                update_shuffle_flow(table)
+                node_dict = get_node_dict()
+                train_node_id = node_dict['模型训练-样本shufle']
+                execute_type = 'EXECUTE_ONE'
+                train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+                train_job_id = train_res['JobId']
+                train_job_detail = wait_job_end(train_job_id)
+                if train_job_detail['Status'] == 'Succeeded':
+                    return True
+        return False
+    except Exception as e:
+        error_message = f"在执行 shuffle_train_model 函数时发生异常: {str(e)}"
+        print(error_message)
+        raise Exception(error_message)
+
+
+@retry
+def export_model():
+    try:
+        node_dict = get_node_dict()
         export_node_id = node_dict['模型导出-2']
+        execute_type = 'EXECUTE_ONE'
         export_res = PAIClient.create_job(experiment_id, export_node_id, execute_type)
         export_job_id = export_res['JobId']
         export_job_detail = wait_job_end(export_job_id)
         if export_job_detail['Status'] == 'Succeeded':
             return True
-    return False
+        return False
+    except Exception as e:
+        error_message = f"在执行 export_model 函数时发生异常: {str(e)}"
+        print(error_message)
+        raise Exception(error_message)
 
 
 def update_online_model():
-    node_dict = get_node_dict()
-    experiment_id = "draft-7u3e9v1uc5pohjl0t6"
-    train_node_id = node_dict['更新EAS服务(Beta)-1']
-    execute_type = 'EXECUTE_ONE'
-    train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
-    train_job_id = train_res['JobId']
-    train_job_detail = wait_job_end(train_job_id)
-    if train_job_detail['Status'] == 'Succeeded':
-        return True
-    return False
+    try:
+        node_dict = get_node_dict()
+        train_node_id = node_dict['更新EAS服务(Beta)-1']
+        execute_type = 'EXECUTE_ONE'
+        train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+        train_job_id = train_res['JobId']
+        train_job_detail = wait_job_end(train_job_id)
+        if train_job_detail['Status'] == 'Succeeded':
+            return True
+        return False
+    except Exception as e:
+        error_message = f"在执行 update_online_model 函数时发生异常: {str(e)}"
+        print(error_message)
+        raise Exception(error_message)
 
 
-def validate_model_data_accuracy():
-    node_dict = get_node_dict()
-    train_node_id = node_dict['虚拟起始节点']
-    execute_type = 'EXECUTE_FROM_HERE'
-    validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
-    print(validate_res)
-    validate_job_id = validate_res['JobId']
-    print(validate_job_id)
-    validate_job_detail = wait_job_end(validate_job_id)
-    print('res')
-    print(validate_job_detail)
-    if validate_job_detail['Status'] == 'Succeeded':
-        pipeline_run_id = validate_job_detail['RunId']
-        node_id = validate_job_detail['PaiflowNodeId']
-        flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
-        print(flow_out_put_detail)
-        tabel_dict = {}
-        out_puts = flow_out_put_detail['Outputs']
-        for out_put in out_puts:
-            if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
-                value1 = json.loads(out_put["Info"]['value'])
-                tabel_dict['二分类评估-1'] = value1['location']['table']
-            if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
-                value2 = json.loads(out_put["Info"]['value'])
-                tabel_dict['二分类评估-2'] = value2['location']['table']
-            if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
-                value3 = json.loads(out_put["Info"]['value'])
-                tabel_dict['预测结果对比'] = value3['location']['table']
+@retry
+def get_validate_model_data():
+    try:
+        node_dict = get_node_dict()
+        train_node_id = node_dict['虚拟起始节点']
+        execute_type = 'EXECUTE_FROM_HERE'
+        validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+        validate_job_id = validate_res['JobId']
+        validate_job_detail = wait_job_end(validate_job_id)
+        if validate_job_detail['Status'] == 'Succeeded':
+            return True
+        return False
+    except Exception as e:
+        error_message = f"在执行 get_validate_model_data 函数时出现异常: {e}"
+        print(error_message)
+        raise Exception(error_message)
+
 
+def validate_model_data_accuracy():
+    try:
+        table_dict = {}
+        node_dict = get_node_dict()
+        job_dict = get_job_dict()
+        job_id = job_dict['虚拟起始节点']
+        validate_job_detail = wait_job_end(job_id)
+        if validate_job_detail['Status'] == 'Succeeded':
+            pipeline_run_id = validate_job_detail['RunId']
+            node_id = validate_job_detail['PaiflowNodeId']
+            flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
+            print(flow_out_put_detail)
+            out_puts = flow_out_put_detail['Outputs']
+            for out_put in out_puts:
+                if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
+                    value1 = json.loads(out_put["Info"]['value'])
+                    table_dict['二分类评估-1'] = value1['location']['table']
+                if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
+                    value2 = json.loads(out_put["Info"]['value'])
+                    table_dict['二分类评估-2'] = value2['location']['table']
+                if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
+                    value3 = json.loads(out_put["Info"]['value'])
+                    table_dict['预测结果对比'] = value3['location']['table']
         num = 10
-        df = get_data_from_odps('pai_algo', tabel_dict['预测结果对比'], 10)
+        df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], 10)
         # 对指定列取绝对值再求和
         old_abs_avg = df['old_error'].abs().sum() / num
         new_abs_avg = df['new_error'].abs().sum() / num
-        new_auc = get_dict_from_odps('pai_algo', tabel_dict['二分类评估-1'])['AUC']
-        old_auc = get_dict_from_odps('pai_algo', tabel_dict['二分类评估-2'])['AUC']
+        new_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-1'])['AUC']
+        old_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-2'])['AUC']
         bizdate = get_previous_days_date(1)
         score_diff = abs(old_abs_avg - new_abs_avg)
         msg = ""
-        level = ""
+        result = False
         if new_abs_avg > 0.1:
             msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
             level = 'error'
@@ -557,11 +686,9 @@ def validate_model_data_accuracy():
             msg += f'两个模型评估${bizdate}的数据,两个模型分数差异为: ${score_diff}, 大于0.05, 请检查'
             level = 'error'
         else:
-            # update_online_model()
             msg += 'DNN广告模型更新完成'
             level = 'info'
-        step_end_time = int(time.time())
-        elapsed = step_end_time - start_time
+            result = True
 
         # 初始化表格头部
         top10_msg = "| CID  | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
@@ -578,78 +705,44 @@ def validate_model_data_accuracy():
         msg += f"\n\t - 新模型AUC: {new_auc}"
         msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
         msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
-        _monitor(level, msg, start_time, elapsed, top10_msg)
+        return result, msg, level, top10_msg
+
+    except Exception as e:
+        error_message = f"在执行 validate_model_data_accuracy 函数时出现异常: {str(e)}"
+        print(error_message)
+        raise Exception(error_message)
 
 
 if __name__ == '__main__':
     start_time = int(time.time())
-    # 1.更新工作流
-    update_online_new_flow()
-
-    # 2.训练模型
-    train_model()
-
-    # 3. 验证模型数据 & 更新模型到线上
-    # validate_model_data_accuracy()
-    # start_time = int(time.time())
-    # node_dict = get_node_dict()
-    # str = '{"Creator": "204034041838504386", "ExecuteType": "EXECUTE_FROM_HERE", "ExperimentId": "draft-7u3e9v1uc5pohjl0t6", "GmtCreateTime": "2025-04-01T03:17:42.000+00:00", "JobId": "job-8u3ev2uf5ncoexj9p9", "PaiflowNodeId": "node-9wtveoz1tu89tqfoox", "RequestId": "6ED5FFB1-346B-5075-ACC9-029EB77E9F09", "RunId": "flow-lchat027733ttstdc0", "Status": "Succeeded", "WorkspaceId": "96094"}'
-    # validate_job_detail = json.loads(str)
-    # if validate_job_detail['Status'] == 'Succeeded':
-    #     pipeline_run_id = validate_job_detail['RunId']
-    #     node_id = validate_job_detail['PaiflowNodeId']
-    #     flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
-    #     print(flow_out_put_detail)
-    #     tabel_dict = {}
-    #     out_puts = flow_out_put_detail['Outputs']
-    #     for out_put in out_puts:
-    #         if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
-    #             value1 = json.loads(out_put["Info"]['value'])
-    #             tabel_dict['二分类评估-1'] = value1['location']['table']
-    #         if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
-    #             value2 = json.loads(out_put["Info"]['value'])
-    #             tabel_dict['二分类评估-2'] = value2['location']['table']
-    #         if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
-    #             value3 = json.loads(out_put["Info"]['value'])
-    #             tabel_dict['预测结果对比'] = value3['location']['table']
-    #
-    #     num = 10
-    #     df = get_data_from_odps('pai_algo', tabel_dict['预测结果对比'], 10)
-    #     # 对指定列取绝对值再求和
-    #     old_abs_avg = df['old_error'].abs().sum() / num
-    #     new_abs_avg = df['new_error'].abs().sum() / num
-    #     new_auc = get_dict_from_odps('pai_algo', tabel_dict['二分类评估-1'])['AUC']
-    #     old_auc = get_dict_from_odps('pai_algo', tabel_dict['二分类评估-2'])['AUC']
-    #     bizdate = get_previous_days_date(1)
-    #     score_diff = abs(old_abs_avg - new_abs_avg)
-    #     msg = ""
-    #     level = ""
-    #     if new_abs_avg > 0.1:
-    #         msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
-    #         level = 'error'
-    #     elif score_diff > 0.05:
-    #         msg += f'两个模型评估${bizdate}的数据,两个模型分数差异为: ${score_diff}, 大于0.05, 请检查'
-    #         level = 'error'
-    #     else:
-    #         # TODO 更新模型到线上
-    #         msg += 'DNN广告模型更新完成'
-    #         level = 'info'
-    #     step_end_time = int(time.time())
-    #     elapsed = step_end_time - start_time
-    #
-    #     # 初始化表格头部
-    #     top10_msg = "| CID  | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
-    #     top10_msg += "\n| ---- | --------- | -------- |"
-    #
-    #     for index, row in df.iterrows():
-    #         # 获取指定列的元素
-    #         cid = row['cid']
-    #         old_error = row['old_error']
-    #         new_error = row['new_error']
-    #         top10_msg += f"\n| {int(cid)} | {old_error} | {new_error} | "
-    #     print(top10_msg)
-    #     msg += f"\n\t - 老模型AUC: {old_auc}"
-    #     msg += f"\n\t - 新模型AUC: {new_auc}"
-    #     msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
-    #     msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
-        # _monitor('info', msg, start_time, elapsed, top10_msg)
+    functions = [update_online_flow, shuffle_table, shuffle_train_model, export_model, get_validate_model_data]
+    function_names = [func.__name__ for func in functions]
+
+    start_function = None
+    if len(sys.argv) > 1:
+        start_function = sys.argv[1]
+        if start_function not in function_names:
+            print(f"指定的起始函数 {start_function} 不存在,请选择以下函数之一:{', '.join(function_names)}")
+            sys.exit(1)
+
+    start_index = 0
+    if start_function:
+        start_index = function_names.index(start_function)
+
+    for func in functions[start_index:]:
+        if not func():
+            print(f"{func.__name__} 执行失败,后续函数不再执行。")
+            step_end_time = int(time.time())
+            elapsed = step_end_time - start_time
+            _monitor('error', f"DNN模型更新,{func.__name__} 执行失败,后续函数不再执行,请检查", start_time, elapsed, None)
+            break
+    else:
+        print("所有函数都成功执行,可以继续下一步操作。")
+        result, msg, level, top10_msg = validate_model_data_accuracy()
+        if result:
+            # update_online_model()
+            print("success")
+        step_end_time = int(time.time())
+        elapsed = step_end_time - start_time
+        print(level, msg, start_time, elapsed, top10_msg)
+        _monitor(level, msg, start_time, elapsed, top10_msg)
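
Not part of the commit, but useful operational context for the reworked entry point: with no argument the script runs every step in order; passing a step name resumes from that step after a partial failure. A hedged usage sketch:

    # Run the full pipeline (what 02_ad_model_dnn_v11_update.sh now invokes):
    #   python ad/pai_flow_operator.py
    # Resume from a single step, e.g. after the shuffle job already succeeded:
    #   python ad/pai_flow_operator.py shuffle_train_model
    # An unknown name prints the valid steps (update_online_flow, shuffle_table,
    # shuffle_train_model, export_model, get_validate_model_data) and exits with 1.
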

ad/pai_flow_operator2.py  (+0 -748, file deleted)

@@ -1,748 +0,0 @@
-# -*- coding: utf-8 -*-
-import functools
-import os
-import re
-import sys
-import time
-import json
-import pandas as pd
-from alibabacloud_paistudio20210202.client import Client as PaiStudio20210202Client
-from alibabacloud_tea_openapi import models as open_api_models
-from alibabacloud_paistudio20210202 import models as pai_studio_20210202_models
-from alibabacloud_tea_util import models as util_models
-from alibabacloud_tea_util.client import Client as UtilClient
-from alibabacloud_eas20210701.client import Client as eas20210701Client
-from alibabacloud_paiflow20210202 import models as paiflow_20210202_models
-from alibabacloud_paiflow20210202.client import Client as PAIFlow20210202Client
-from datetime import datetime, timedelta
-from odps import ODPS
-from ad_monitor_util import _monitor
-
-target_names = {
-    '样本shuffle',
-    '模型训练-样本shufle',
-    '模型训练-自定义',
-    '模型增量训练',
-    '模型导出-2',
-    '更新EAS服务(Beta)-1',
-    '虚拟起始节点',
-    '二分类评估-1',
-    '二分类评估-2',
-    '预测结果对比'
-}
-
-experiment_id = "draft-wqgkag89sbh9v1zvut"
-
-MAX_RETRIES = 3
-
-
-def retry(func):
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        retries = 0
-        while retries < MAX_RETRIES:
-            try:
-                result = func(*args, **kwargs)
-                if result is not False:
-                    return result
-            except Exception as e:
-                print(f"函数 {func.__name__} 执行时发生异常: {e},重试第 {retries + 1} 次")
-            retries += 1
-        print(f"函数 {func.__name__} 重试 {MAX_RETRIES} 次后仍失败。")
-        return False
-
-    return wrapper
-
-
-def get_odps_instance(project):
-    odps = ODPS(
-        access_id='LTAIWYUujJAm7CbH',
-        secret_access_key='RfSjdiWwED1sGFlsjXv0DlfTnZTG1P',
-        project=project,
-        endpoint='http://service.cn.maxcompute.aliyun.com/api',
-    )
-    return odps
-
-
-def get_data_from_odps(project, table, num):
-    odps = get_odps_instance(project)
-    try:
-        # 要查询的 SQL 语句
-        sql = f'select * from {table} limit {num}'
-        # 执行 SQL 查询
-        with odps.execute_sql(sql).open_reader() as reader:
-            df = reader.to_pandas()
-            # 查询数量小于目标数量时 返回空
-            if len(df) < num:
-                return None
-            return df
-    except Exception as e:
-        print(f"发生错误: {e}")
-
-
-def get_dict_from_odps(project, table):
-    odps = get_odps_instance(project)
-    try:
-        # 要查询的 SQL 语句
-        sql = f'select * from {table}'
-        # 执行 SQL 查询
-        with odps.execute_sql(sql).open_reader() as reader:
-            data = {}
-            for record in reader:
-                record_list = list(record)
-                key = record_list[0][1]
-                value = record_list[1][1]
-                data[key] = value
-            return data
-    except Exception as e:
-        print(f"发生错误: {e}")
-
-
-def get_dates_between(start_date_str, end_date_str):
-    start_date = datetime.strptime(start_date_str, '%Y%m%d')
-    end_date = datetime.strptime(end_date_str, '%Y%m%d')
-    dates = []
-    current_date = start_date
-    while current_date <= end_date:
-        dates.append(current_date.strftime('%Y%m%d'))
-        current_date += timedelta(days=1)
-    return dates
-
-
-def read_file_to_list():
-    try:
-        current_dir = os.getcwd()
-        file_path = os.path.join(current_dir, 'ad', 'holidays.txt')
-        with open(file_path, 'r', encoding='utf-8') as file:
-            content = file.read()
-            return content.split('\n')
-    except FileNotFoundError:
-        raise Exception(f"错误:未找到 {file_path} 文件。")
-    except Exception as e:
-        raise Exception(f"错误:发生了一个未知错误: {e}")
-    return []
-
-
-def get_previous_days_date(days):
-    current_date = datetime.now()
-    previous_date = current_date - timedelta(days=days)
-    return previous_date.strftime('%Y%m%d')
-
-
-def remove_elements(lst1, lst2):
-    return [element for element in lst1 if element not in lst2]
-
-
-def process_list(lst, append_str):
-    # 给列表中每个元素拼接相同的字符串
-    appended_list = [append_str + element for element in lst]
-    # 将拼接后的列表元素用逗号拼接成一个字符串
-    result_str = ','.join(appended_list)
-    return result_str
-
-
-def get_train_data_list():
-    start_date = '20250320'
-    end_date = get_previous_days_date(2)
-    date_list = get_dates_between(start_date, end_date)
-    filter_date_list = read_file_to_list()
-    date_list = remove_elements(date_list, filter_date_list)
-    return date_list
-
-
-def update_train_tables(old_str):
-    date_list = get_train_data_list()
-    train_list = ["'" + item + "'" for item in date_list]
-    result = ','.join(train_list)
-    start_index = old_str.find('where dt in (')
-    if start_index != -1:
-        equal_sign_index = start_index + len('where dt in (')
-        # 找到下一个双引号的位置
-        next_quote_index = old_str.find(')', equal_sign_index)
-        if next_quote_index != -1:
-            # 进行替换
-            new_value = old_str[:equal_sign_index] + result + old_str[next_quote_index:]
-            return new_value
-    return None
-
-
-def compare_timestamp_with_today_start(time_str):
-    # 解析时间字符串为 datetime 对象
-    time_obj = datetime.fromisoformat(time_str)
-    # 将其转换为时间戳
-    target_timestamp = time_obj.timestamp()
-    # 获取今天开始的时间
-    today_start = datetime.combine(datetime.now().date(), datetime.min.time())
-    # 将今天开始时间转换为时间戳
-    today_start_timestamp = today_start.timestamp()
-    return target_timestamp > today_start_timestamp
-
-
-def update_train_table(old_str, table):
-    address = 'odps://pai_algo/tables/'
-    train_table = address + table
-    start_index = old_str.find('-Dtrain_tables="')
-    if start_index != -1:
-        # 确定等号的位置
-        equal_sign_index = start_index + len('-Dtrain_tables="')
-        # 找到下一个双引号的位置
-        next_quote_index = old_str.find('"', equal_sign_index)
-        if next_quote_index != -1:
-            # 进行替换
-            new_value = old_str[:equal_sign_index] + train_table + old_str[next_quote_index:]
-            return new_value
-    return None
-
-
-class PAIClient:
-    def __init__(self):
-        pass
-
-    @staticmethod
-    def create_client() -> PaiStudio20210202Client:
-        """
-        使用AK&SK初始化账号Client
-        @return: Client
-        @throws Exception
-        """
-        # 工程代码泄露可能会导致 AccessKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考。
-        # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
-        config = open_api_models.Config(
-            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
-            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
-        )
-        # Endpoint 请参考 https://api.aliyun.com/product/PaiStudio
-        config.endpoint = f'pai.cn-hangzhou.aliyuncs.com'
-        return PaiStudio20210202Client(config)
-
-    @staticmethod
-    def create_eas_client() -> eas20210701Client:
-        """
-        使用AK&SK初始化账号Client
-        @return: Client
-        @throws Exception
-        """
-        # 工程代码泄露可能会导致 AccessKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考。
-        # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
-        config = open_api_models.Config(
-            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
-            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
-        )
-        # Endpoint 请参考 https://api.aliyun.com/product/PaiStudio
-        config.endpoint = f'pai-eas.cn-hangzhou.aliyuncs.com'
-        return eas20210701Client(config)
-
-    @staticmethod
-    def create_flow_client() -> PAIFlow20210202Client:
-        """
-        使用AK&SK初始化账号Client
-        @return: Client
-        @throws Exception
-        """
-        # 工程代码泄露可能会导致 AccessKey 泄露,并威胁账号下所有资源的安全性。以下代码示例仅供参考。
-        # 建议使用更安全的 STS 方式,更多鉴权访问方式请参见:https://help.aliyun.com/document_detail/378659.html。
-        config = open_api_models.Config(
-            # 必填,请确保代码运行环境设置了环境变量 ALIBABA_CLOUD_ACCESS_KEY_ID。,
-            access_key_id="LTAI5tFGqgC8f3mh1fRCrAEy",
-            # 必填,请确保代码运行环境设置了环境变量 ALIBABA_CLOUD_ACCESS_KEY_SECRET。,
-            access_key_secret="XhOjK9XmTYRhVAtf6yii4s4kZwWzvV"
-        )
-        # Endpoint 请参考 https://api.aliyun.com/product/PAIFlow
-        config.endpoint = f'paiflow.cn-hangzhou.aliyuncs.com'
-        return PAIFlow20210202Client(config)
-
-    @staticmethod
-    def get_work_flow_draft_list(workspace_id: str):
-        client = PAIClient.create_client()
-        list_experiments_request = pai_studio_20210202_models.ListExperimentsRequest(
-            workspace_id=workspace_id
-        )
-        runtime = util_models.RuntimeOptions()
-        headers = {}
-        try:
-            resp = client.list_experiments_with_options(list_experiments_request, headers, runtime)
-            return resp.body.to_map()
-        except Exception as error:
-            raise Exception(f"get_work_flow_draft_list error {error}")
-
-    @staticmethod
-    def get_work_flow_draft(experiment_id: str):
-        client = PAIClient.create_client()
-        runtime = util_models.RuntimeOptions()
-        headers = {}
-        try:
-            # 复制代码运行请自行打印 API 的返回值
-            resp = client.get_experiment_with_options(experiment_id, headers, runtime)
-            return resp.body.to_map()
-        except Exception as error:
-            raise Exception(f"get_work_flow_draft error {error}")
-
-    @staticmethod
-    def get_describe_service(service_name: str):
-        client = PAIClient.create_eas_client()
-        runtime = util_models.RuntimeOptions()
-        headers = {}
-        try:
-            # 复制代码运行请自行打印 API 的返回值
-            resp = client.describe_service_with_options('cn-hangzhou', service_name, headers, runtime)
-            return resp.body.to_map()
-        except Exception as error:
-            raise Exception(f"get_describe_service error {error}")
-
-    @staticmethod
-    def update_experiment_content(experiment_id: str, content: str, version: int):
-        client = PAIClient.create_client()
-        update_experiment_content_request = pai_studio_20210202_models.UpdateExperimentContentRequest(content=content,
-                                                                                                      version=version)
-        runtime = util_models.RuntimeOptions()
-        headers = {}
-        try:
-            # 复制代码运行请自行打印 API 的返回值
-            resp = client.update_experiment_content_with_options(experiment_id, update_experiment_content_request,
-                                                                 headers, runtime)
-            print(resp.body.to_map())
-        except Exception as error:
-            raise Exception(f"update_experiment_content error {error}")
-
-    @staticmethod
-    def create_job(experiment_id: str, node_id: str, execute_type: str):
-        client = PAIClient.create_client()
-        create_job_request = pai_studio_20210202_models.CreateJobRequest()
-        create_job_request.experiment_id = experiment_id
-        create_job_request.node_id = node_id
-        create_job_request.execute_type = execute_type
-        runtime = util_models.RuntimeOptions()
-        headers = {}
-        try:
-            # 复制代码运行请自行打印 API 的返回值
-            resp = client.create_job_with_options(create_job_request, headers, runtime)
-            return resp.body.to_map()
-        except Exception as error:
-            raise Exception(f"create_job error {error}")
-
-    @staticmethod
-    def get_jobs_list(experiment_id: str, order='DESC'):
-        client = PAIClient.create_client()
-        list_jobs_request = pai_studio_20210202_models.ListJobsRequest(
-            experiment_id=experiment_id,
-            order=order
-        )
-        runtime = util_models.RuntimeOptions()
-        headers = {}
-        try:
-            # 复制代码运行请自行打印 API 的返回值
-            resp = client.list_jobs_with_options(list_jobs_request, headers, runtime)
-            return resp.body.to_map()
-        except Exception as error:
-            raise Exception(f"get_jobs_list error {error}")
-
-    @staticmethod
-    def get_job_detail(job_id: str, verbose=False):
-        client = PAIClient.create_client()
-        get_job_request = pai_studio_20210202_models.GetJobRequest(
-            verbose=verbose
-        )
-        runtime = util_models.RuntimeOptions()
-        headers = {}
-        try:
-            # 复制代码运行请自行打印 API 的返回值
-            resp = client.get_job_with_options(job_id, get_job_request, headers, runtime)
-            return resp.body.to_map()
-        except Exception as error:
-            # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
-            # 错误 message
-            print(error.message)
-            # 诊断地址
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
-
-    @staticmethod
-    def get_flow_out_put(pipeline_run_id: str, node_id: str, depth: int):
-        client = PAIClient.create_flow_client()
-        list_pipeline_run_node_outputs_request = paiflow_20210202_models.ListPipelineRunNodeOutputsRequest(
-            depth=depth
-        )
-        runtime = util_models.RuntimeOptions()
-        headers = {}
-        try:
-            # 复制代码运行请自行打印 API 的返回值
-            resp = client.list_pipeline_run_node_outputs_with_options(pipeline_run_id, node_id,
-                                                                      list_pipeline_run_node_outputs_request, headers,
-                                                                      runtime)
-            return resp.body.to_map()
-        except Exception as error:
-            # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
-            # 错误 message
-            print(error.message)
-            # 诊断地址
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
-
-
-def extract_date_yyyymmdd(input_string):
-    pattern = r'\d{8}'
-    matches = re.findall(pattern, input_string)
-    if matches:
-        return matches[0]
-    return None
-
-
-def get_online_model_config(service_name: str):
-    model_config = {}
-    model_detail = PAIClient.get_describe_service(service_name)
-    service_config_str = model_detail['ServiceConfig']
-    service_config = json.loads(service_config_str)
-    model_path = service_config['model_path']
-    model_config['model_path'] = model_path
-    online_date = extract_date_yyyymmdd(model_path)
-    model_config['online_date'] = online_date
-    return model_config
-
-
-
-def update_shuffle_flow(table):
-    draft = PAIClient.get_work_flow_draft(experiment_id)
-    print(json.dumps(draft, ensure_ascii=False))
-    content = draft['Content']
-    version = draft['Version']
-    content_json = json.loads(content)
-    nodes = content_json.get('nodes')
-    for node in nodes:
-        name = node['name']
-        if name == '模型训练-样本shufle':
-            properties = node['properties']
-            for property in properties:
-                if property['name'] == 'sql':
-                    value = property['value']
-                    new_value = update_train_table(value, table)
-                    if new_value is None:
-                        print("error")
-                    property['value'] = new_value
-    new_content = json.dumps(content_json, ensure_ascii=False)
-    PAIClient.update_experiment_content(experiment_id, new_content, version)
-
-
-def update_shuffle_flow_1():
-    draft = PAIClient.get_work_flow_draft(experiment_id)
-    print(json.dumps(draft, ensure_ascii=False))
-    content = draft['Content']
-    version = draft['Version']
-    print(content)
-    content_json = json.loads(content)
-    nodes = content_json.get('nodes')
-    for node in nodes:
-        name = node['name']
-        if name == '模型训练-样本shufle':
-            properties = node['properties']
-            for property in properties:
-                if property['name'] == 'sql':
-                    value = property['value']
-                    new_value = update_train_tables(value)
-                    if new_value is None:
-                        print("error")
-                    property['value'] = new_value
-    new_content = json.dumps(content_json, ensure_ascii=False)
-    PAIClient.update_experiment_content(experiment_id, new_content, version)
-
-
-def wait_job_end(job_id: str):
-    while True:
-        job_detail = PAIClient.get_job_detail(job_id)
-        print(job_detail)
-        statue = job_detail['Status']
-        # Initialized: 初始化完成 Starting:开始 WorkflowServiceStarting:准备提交 Running:运行中 ReadyToSchedule:准备运行(前序节点未完成导致)
-        if (statue == 'Initialized' or statue == 'Starting' or statue == 'WorkflowServiceStarting'
-                or statue == 'Running' or statue == 'ReadyToSchedule'):
-            # 睡眠300s 等待下次获取
-            time.sleep(300)
-            continue
-        # Failed:运行失败 Terminating:终止中 Terminated:已终止 Unknown:未知 Skipped:跳过(前序节点失败导致) Succeeded:运行成功
-        if statue == 'Failed' or statue == 'Terminating' or statue == 'Unknown' or statue == 'Skipped' or statue == 'Succeeded':
-            return job_detail
-
-
-def get_node_dict():
-    draft = PAIClient.get_work_flow_draft(experiment_id)
-    content = draft['Content']
-    content_json = json.loads(content)
-    nodes = content_json.get('nodes')
-    node_dict = {}
-    for node in nodes:
-        name = node['name']
-        # 检查名称是否在目标名称集合中
-        if name in target_names:
-            node_dict[name] = node['id']
-    return node_dict
-
-
-def get_job_dict():
-    job_dict = {}
-    jobs_list = PAIClient.get_jobs_list(experiment_id)
-    for job in jobs_list['Jobs']:
-        # 解析时间字符串为 datetime 对象
-        if not compare_timestamp_with_today_start(job['GmtCreateTime']):
-            break
-        job_id = job['JobId']
-        job_detail = PAIClient.get_job_detail(job_id, verbose=True)
-        for name in target_names:
-            if job_detail['Status'] != 'Succeeded':
-                continue
-            if name in job_dict:
-                continue
-            if name in job_detail['RunInfo']:
-                job_dict[name] = job_detail['JobId']
-    return job_dict
-
-@retry
-def update_online_flow():
-    try:
-        online_model_config = get_online_model_config('ad_rank_dnn_v11_easyrec')
-        draft = PAIClient.get_work_flow_draft(experiment_id)
-        print(json.dumps(draft, ensure_ascii=False))
-        content = draft['Content']
-        version = draft['Version']
-        print(content)
-        content_json = json.loads(content)
-        nodes = content_json.get('nodes')
-        global_params = content_json.get('globalParams')
-        bizdate = get_previous_days_date(1)
-        for global_param in global_params:
-            try:
-                if global_param['name'] == 'bizdate':
-                    global_param['value'] = bizdate
-                if global_param['name'] == 'online_version_dt':
-                    global_param['value'] = online_model_config['online_date']
-                if global_param['name'] == 'eval_date':
-                    global_param['value'] = bizdate
-                if global_param['name'] == 'online_model_path':
-                    global_param['value'] = online_model_config['model_path']
-            except KeyError:
-                raise Exception("在处理全局参数时,字典中缺少必要的键")
-        for node in nodes:
-            try:
-                name = node['name']
-                if name == '样本shuffle':
-                    properties = node['properties']
-                    for property in properties:
-                        if property['name'] == 'sql':
-                            value = property['value']
-                            new_value = update_train_tables(value)
-                            if new_value is None:
-                                print("error")
-                            property['value'] = new_value
-            except KeyError:
-                raise Exception("在处理节点属性时,字典中缺少必要的键")
-        new_content = json.dumps(content_json, ensure_ascii=False)
-        PAIClient.update_experiment_content(experiment_id, new_content, version)
-        return True
-    except json.JSONDecodeError:
-        raise Exception("JSON 解析错误,可能是草稿内容格式不正确")
-    except Exception as e:
-        raise Exception(f"发生未知错误: {e}")
-
-@retry
-def shuffle_table():
-    try:
-        node_dict = get_node_dict()
-        train_node_id = node_dict['样本shuffle']
-        execute_type = 'EXECUTE_ONE'
-        validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
-        validate_job_id = validate_res['JobId']
-        validate_job_detail = wait_job_end(validate_job_id)
-        if validate_job_detail['Status'] == 'Succeeded':
-            return True
-        return False
-    except Exception as e:
-        error_message = f"在执行 shuffle_table 函数时发生异常: {str(e)}"
-        print(error_message)
-        raise Exception(error_message)
-
-
-@retry
-def shuffle_train_model():
-    try:
-        node_dict = get_node_dict()
-        job_dict = get_job_dict()
-        job_id = job_dict['样本shuffle']
-        validate_job_detail = wait_job_end(job_id)
-        if validate_job_detail['Status'] == 'Succeeded':
-            pipeline_run_id = validate_job_detail['RunId']
-            node_id = validate_job_detail['PaiflowNodeId']
-            flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
-            out_puts = flow_out_put_detail['Outputs']
-            table = None
-            for out_put in out_puts:
-                if out_put["Producer"] == node_dict['样本shuffle'] and out_put["Name"] == "outputTable":
-                    value1 = json.loads(out_put["Info"]['value'])
-                    table = value1['location']['table']
-            if table is not None:
-                update_shuffle_flow(table)
-                node_dict = get_node_dict()
-                train_node_id = node_dict['模型训练-样本shufle']
-                execute_type = 'EXECUTE_ONE'
-                train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
-                train_job_id = train_res['JobId']
-                train_job_detail = wait_job_end(train_job_id)
-                if train_job_detail['Status'] == 'Succeeded':
-                    return True
-        return False
-    except Exception as e:
-        error_message = f"在执行 shuffle_train_model 函数时发生异常: {str(e)}"
-        print(error_message)
-        raise Exception(error_message)
-
-
-@retry
-def export_model():
-    try:
-        node_dict = get_node_dict()
-        export_node_id = node_dict['模型导出-2']
-        execute_type = 'EXECUTE_ONE'
-        export_res = PAIClient.create_job(experiment_id, export_node_id, execute_type)
-        export_job_id = export_res['JobId']
-        export_job_detail = wait_job_end(export_job_id)
-        if export_job_detail['Status'] == 'Succeeded':
-            return True
-        return False
-    except Exception as e:
-        error_message = f"在执行 export_model 函数时发生异常: {str(e)}"
-        print(error_message)
-        raise Exception(error_message)
-
-
-def update_online_model():
-    try:
-        node_dict = get_node_dict()
-        train_node_id = node_dict['更新EAS服务(Beta)-1']
-        execute_type = 'EXECUTE_ONE'
-        train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
-        train_job_id = train_res['JobId']
-        train_job_detail = wait_job_end(train_job_id)
-        if train_job_detail['Status'] == 'Succeeded':
-            return True
-        return False
-    except Exception as e:
-        error_message = f"在执行 update_online_model 函数时发生异常: {str(e)}"
-        print(error_message)
-        raise Exception(error_message)
-
-
-@retry
-def get_validate_model_data():
-    try:
-        node_dict = get_node_dict()
-        train_node_id = node_dict['虚拟起始节点']
-        execute_type = 'EXECUTE_FROM_HERE'
-        validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
-        validate_job_id = validate_res['JobId']
-        validate_job_detail = wait_job_end(validate_job_id)
-        if validate_job_detail['Status'] == 'Succeeded':
-            return True
-        return False
-    except Exception as e:
-        error_message = f"在执行 get_validate_model_data 函数时出现异常: {e}"
-        print(error_message)
-        raise Exception(error_message)
-
-
-def validate_model_data_accuracy():
-    try:
-        table_dict = {}
-        node_dict = get_node_dict()
-        job_dict = get_job_dict()
-        job_id = job_dict['虚拟起始节点']
-        validate_job_detail = wait_job_end(job_id)
-        if validate_job_detail['Status'] == 'Succeeded':
-            pipeline_run_id = validate_job_detail['RunId']
-            node_id = validate_job_detail['PaiflowNodeId']
-            flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
-            print(flow_out_put_detail)
-            out_puts = flow_out_put_detail['Outputs']
-            for out_put in out_puts:
-                if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
-                    value1 = json.loads(out_put["Info"]['value'])
-                    table_dict['二分类评估-1'] = value1['location']['table']
-                if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
-                    value2 = json.loads(out_put["Info"]['value'])
-                    table_dict['二分类评估-2'] = value2['location']['table']
-                if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
-                    value3 = json.loads(out_put["Info"]['value'])
-                    table_dict['预测结果对比'] = value3['location']['table']
-        num = 10
-        df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], 10)
-        # 对指定列取绝对值再求和
-        old_abs_avg = df['old_error'].abs().sum() / num
-        new_abs_avg = df['new_error'].abs().sum() / num
-        new_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-1'])['AUC']
-        old_auc = get_dict_from_odps('pai_algo', table_dict['二分类评估-2'])['AUC']
-        bizdate = get_previous_days_date(1)
-        score_diff = abs(old_abs_avg - new_abs_avg)
-        msg = ""
-        result = False
-        if new_abs_avg > 0.1:
-            msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
-            level = 'error'
-        elif score_diff > 0.05:
-            msg += f'两个模型评估${bizdate}的数据,两个模型分数差异为: ${score_diff}, 大于0.05, 请检查'
-            level = 'error'
-        else:
-            msg += 'DNN广告模型更新完成'
-            level = 'info'
-            result = True
-
-        # 初始化表格头部
-        top10_msg = "| CID  | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
-        top10_msg += "\n| ---- | --------- | -------- |"
-
-        for index, row in df.iterrows():
-            # 获取指定列的元素
-            cid = row['cid']
-            old_error = row['old_error']
-            new_error = row['new_error']
-            top10_msg += f"\n| {int(cid)} | {old_error} | {new_error} | "
-        print(top10_msg)
-        msg += f"\n\t - 老模型AUC: {old_auc}"
-        msg += f"\n\t - 新模型AUC: {new_auc}"
-        msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
-        msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
-        return result, msg, level, top10_msg
-
-    except Exception as e:
-        error_message = f"在执行 validate_model_data_accuracy 函数时出现异常: {str(e)}"
-        print(error_message)
-        raise Exception(error_message)
-
-
-if __name__ == '__main__':
-    start_time = int(time.time())
-    functions = [update_online_flow, shuffle_table, shuffle_train_model, export_model, get_validate_model_data]
-    function_names = [func.__name__ for func in functions]
-
-    start_function = None
-    if len(sys.argv) > 1:
-        start_function = sys.argv[1]
-        if start_function not in function_names:
-            print(f"指定的起始函数 {start_function} 不存在,请选择以下函数之一:{', '.join(function_names)}")
-            sys.exit(1)
-
-    start_index = 0
-    if start_function:
-        start_index = function_names.index(start_function)
-
-    for func in functions[start_index:]:
-        if not func():
-            print(f"{func.__name__} 执行失败,后续函数不再执行。")
-            step_end_time = int(time.time())
-            elapsed = step_end_time - start_time
-            _monitor('error', f"DNN模型更新,{func.__name__} 执行失败,后续函数不再执行,请检查", start_time, elapsed, None)
-            break
-    else:
-        print("所有函数都成功执行,可以继续下一步操作。")
-        result, msg, level, top10_msg = validate_model_data_accuracy()
-        if result:
-            # update_online_model()
-            print("success")
-        step_end_time = int(time.time())
-        elapsed = step_end_time - start_time
-        print(level, msg, start_time, elapsed, top10_msg)
-        _monitor(level, msg, start_time, elapsed, top10_msg)