Browse Source

优化pai自动更新脚本

xueyiming 2 weeks ago
parent
commit
fa7ea7145d
1 changed files with 258 additions and 139 deletions
  1. 258 139
      ad/pai_flow_operator2.py

+ 258 - 139
ad/pai_flow_operator2.py

@@ -31,6 +31,25 @@ target_names = {
 
 experiment_id = "draft-wqgkag89sbh9v1zvut"
 
+MAX_RETRIES = 3
+
+
+def retry(func):
+    def wrapper(*args, **kwargs):
+        retries = 0
+        while retries < MAX_RETRIES:
+            try:
+                result = func(*args, **kwargs)
+                if result is not False:
+                    return result
+            except Exception as e:
+                print(f"函数 {func.__name__} 执行时发生异常: {e},重试第 {retries + 1} 次")
+            retries += 1
+        print(f"函数 {func.__name__} 重试 {MAX_RETRIES} 次后仍失败。")
+        return False
+
+    return wrapper
+
 
 def get_odps_instance(project):
     odps = ODPS(
@@ -144,6 +163,18 @@ def update_train_tables(old_str):
     return None
 
 
+def compare_timestamp_with_today_start(time_str):
+    # 解析时间字符串为 datetime 对象
+    time_obj = datetime.fromisoformat(time_str)
+    # 将其转换为时间戳
+    target_timestamp = time_obj.timestamp()
+    # 获取今天开始的时间
+    today_start = datetime.combine(datetime.now().date(), datetime.min.time())
+    # 将今天开始时间转换为时间戳
+    today_start_timestamp = today_start.timestamp()
+    return target_timestamp > today_start_timestamp
+
+
 def update_train_table(old_str, table):
     address = 'odps://pai_algo/tables/'
     train_table = address + table
@@ -229,9 +260,7 @@ class PAIClient:
             resp = client.list_experiments_with_options(list_experiments_request, headers, runtime)
             return resp.body.to_map()
         except Exception as error:
-            print(error.message)
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
+            raise Exception(f"get_work_flow_draft_list error {error}")
 
     @staticmethod
     def get_work_flow_draft(experiment_id: str):
@@ -243,12 +272,7 @@ class PAIClient:
             resp = client.get_experiment_with_options(experiment_id, headers, runtime)
             return resp.body.to_map()
         except Exception as error:
-            # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
-            # 错误 message
-            print(error.message)
-            # 诊断地址
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
+            raise Exception(f"get_work_flow_draft error {error}")
 
     @staticmethod
     def get_describe_service(service_name: str):
@@ -260,12 +284,7 @@ class PAIClient:
             resp = client.describe_service_with_options('cn-hangzhou', service_name, headers, runtime)
             return resp.body.to_map()
         except Exception as error:
-            # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
-            # 错误 message
-            print(error.message)
-            # 诊断地址
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
+            raise Exception(f"get_describe_service error {error}")
 
     @staticmethod
     def update_experiment_content(experiment_id: str, content: str, version: int):
@@ -280,12 +299,7 @@ class PAIClient:
                                                                  headers, runtime)
             print(resp.body.to_map())
         except Exception as error:
-            # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
-            # 错误 message
-            print(error.message)
-            # 诊断地址
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
+            raise Exception(f"update_experiment_content error {error}")
 
     @staticmethod
     def create_job(experiment_id: str, node_id: str, execute_type: str):
@@ -301,18 +315,29 @@ class PAIClient:
             resp = client.create_job_with_options(create_job_request, headers, runtime)
             return resp.body.to_map()
         except Exception as error:
-            # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。
-            # 错误 message
-            print(error.message)
-            # 诊断地址
-            print(error.data.get("Recommend"))
-            UtilClient.assert_as_string(error.message)
+            raise Exception(f"create_job error {error}")
+
+    @staticmethod
+    def get_jobs_list(experiment_id: str, order='DESC'):
+        client = PAIClient.create_client()
+        list_jobs_request = pai_studio_20210202_models.ListJobsRequest(
+            experiment_id=experiment_id,
+            order=order
+        )
+        runtime = util_models.RuntimeOptions()
+        headers = {}
+        try:
+            # 复制代码运行请自行打印 API 的返回值
+            resp = client.list_jobs_with_options(list_jobs_request, headers, runtime)
+            return resp.body.to_map()
+        except Exception as error:
+            raise Exception(f"get_jobs_list error {error}")
 
     @staticmethod
-    def get_job_detail(job_id: str):
+    def get_job_detail(job_id: str, verbose=False):
         client = PAIClient.create_client()
         get_job_request = pai_studio_20210202_models.GetJobRequest(
-            verbose=False
+            verbose=verbose
         )
         runtime = util_models.RuntimeOptions()
         headers = {}
@@ -369,36 +394,47 @@ def get_online_version_dt(service_name: str):
 
 
 def update_online_flow():
-    online_version_dt = get_online_version_dt('ad_rank_dnn_v11_easyrec')
-    draft = PAIClient.get_work_flow_draft(experiment_id)
-    print(json.dumps(draft, ensure_ascii=False))
-    content = draft['Content']
-    version = draft['Version']
-    print(content)
-    content_json = json.loads(content)
-    nodes = content_json.get('nodes')
-    global_params = content_json.get('globalParams')
-    bizdate = get_previous_days_date(1)
-    for global_param in global_params:
-        if global_param['name'] == 'bizdate':
-            global_param['value'] = bizdate
-        if global_param['name'] == 'online_version_dt':
-            global_param['value'] = online_version_dt
-        if global_param['name'] == 'eval_date':
-            global_param['value'] = bizdate
-    for node in nodes:
-        name = node['name']
-        if name == '样本shuffle':
-            properties = node['properties']
-            for property in properties:
-                if property['name'] == 'sql':
-                    value = property['value']
-                    new_value = update_train_tables(value)
-                    if new_value is None:
-                        print("error")
-                    property['value'] = new_value
-    new_content = json.dumps(content_json, ensure_ascii=False)
-    PAIClient.update_experiment_content(experiment_id, new_content, version)
+    try:
+        online_version_dt = get_online_version_dt('ad_rank_dnn_v11_easyrec')
+        draft = PAIClient.get_work_flow_draft(experiment_id)
+        print(json.dumps(draft, ensure_ascii=False))
+        content = draft['Content']
+        version = draft['Version']
+        print(content)
+        content_json = json.loads(content)
+        nodes = content_json.get('nodes')
+        global_params = content_json.get('globalParams')
+        bizdate = get_previous_days_date(1)
+        for global_param in global_params:
+            try:
+                if global_param['name'] == 'bizdate':
+                    global_param['value'] = bizdate
+                if global_param['name'] == 'online_version_dt':
+                    global_param['value'] = online_version_dt
+                if global_param['name'] == 'eval_date':
+                    global_param['value'] = bizdate
+            except KeyError:
+                raise Exception("在处理全局参数时,字典中缺少必要的键")
+        for node in nodes:
+            try:
+                name = node['name']
+                if name == '样本shuffle':
+                    properties = node['properties']
+                    for property in properties:
+                        if property['name'] == 'sql':
+                            value = property['value']
+                            new_value = update_train_tables(value)
+                            if new_value is None:
+                                print("error")
+                            property['value'] = new_value
+            except KeyError:
+                raise Exception("在处理节点属性时,字典中缺少必要的键")
+        new_content = json.dumps(content_json, ensure_ascii=False)
+        PAIClient.update_experiment_content(experiment_id, new_content, version)
+    except json.JSONDecodeError:
+        raise Exception("JSON 解析错误,可能是草稿内容格式不正确")
+    except Exception as e:
+        raise Exception(f"发生未知错误: {e}")
 
 
 def update_shuffle_flow(table):
@@ -476,78 +512,152 @@ def get_node_dict():
     return node_dict
 
 
-def train_model():
-    node_dict = get_node_dict()
-    train_node_id = node_dict['样本shuffle']
-    execute_type = 'EXECUTE_ONE'
-    validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
-    validate_job_id = validate_res['JobId']
-    validate_job_detail = wait_job_end(validate_job_id)
-    if validate_job_detail['Status'] == 'Succeeded':
-        pipeline_run_id = validate_job_detail['RunId']
-        node_id = validate_job_detail['PaiflowNodeId']
-        flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
-        out_puts = flow_out_put_detail['Outputs']
-        table = None
-        for out_put in out_puts:
-            if out_put["Producer"] == node_dict['样本shuffle'] and out_put["Name"] == "outputTable":
-                value1 = json.loads(out_put["Info"]['value'])
-                table = value1['location']['table']
-        if table is not None:
-            update_shuffle_flow(table)
-            node_dict = get_node_dict()
-            train_node_id = node_dict['模型训练-样本shufle']
-            execute_type = 'EXECUTE_ONE'
-            train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
-            train_job_id = train_res['JobId']
-            train_job_detail = wait_job_end(train_job_id)
-            if train_job_detail['Status'] == 'Succeeded':
-                export_node_id = node_dict['模型导出-2']
-                export_res = PAIClient.create_job(experiment_id, export_node_id, execute_type)
-                export_job_id = export_res['JobId']
-                export_job_detail = wait_job_end(export_job_id)
-                if export_job_detail['Status'] == 'Succeeded':
+def get_job_dict():
+    job_dict = {}
+    jobs_list = PAIClient.get_jobs_list(experiment_id)
+    for job in jobs_list['Jobs']:
+        # 解析时间字符串为 datetime 对象
+        if not compare_timestamp_with_today_start(job['GmtCreateTime']):
+            break
+        job_id = job['JobId']
+        job_detail = PAIClient.get_job_detail(job_id, verbose=True)
+        for name in target_names:
+            if job_detail['Status'] != 'Succeeded':
+                continue
+            if name in job_dict:
+                continue
+            if name in job_detail['RunInfo']:
+                job_dict[name] = job_detail['JobId']
+    return job_dict
+
+
+@retry
+def shuffle_table():
+    try:
+        node_dict = get_node_dict()
+        train_node_id = node_dict['样本shuffle']
+        execute_type = 'EXECUTE_ONE'
+        validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+        validate_job_id = validate_res['JobId']
+        validate_job_detail = wait_job_end(validate_job_id)
+        if validate_job_detail['Status'] == 'Succeeded':
+            return True
+        return False
+    except Exception as e:
+        error_message = f"在执行 shuffle_table 函数时发生异常: {str(e)}"
+        print(error_message)
+        raise Exception(error_message)
+
+
+@retry
+def shuffle_train_model():
+    try:
+        node_dict = get_node_dict()
+        job_dict = get_job_dict()
+        job_id = job_dict['样本shuffle']
+        validate_job_detail = wait_job_end(job_id)
+        if validate_job_detail['Status'] == 'Succeeded':
+            pipeline_run_id = validate_job_detail['RunId']
+            node_id = validate_job_detail['PaiflowNodeId']
+            flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 2)
+            out_puts = flow_out_put_detail['Outputs']
+            table = None
+            for out_put in out_puts:
+                if out_put["Producer"] == node_dict['样本shuffle'] and out_put["Name"] == "outputTable":
+                    value1 = json.loads(out_put["Info"]['value'])
+                    table = value1['location']['table']
+            if table is not None:
+                update_shuffle_flow(table)
+                node_dict = get_node_dict()
+                train_node_id = node_dict['模型训练-样本shufle']
+                execute_type = 'EXECUTE_ONE'
+                train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+                train_job_id = train_res['JobId']
+                train_job_detail = wait_job_end(train_job_id)
+                if train_job_detail['Status'] == 'Succeeded':
                     return True
-    return False
+        return False
+    except Exception as e:
+        error_message = f"在执行 shuffle_train_model 函数时发生异常: {str(e)}"
+        print(error_message)
+        raise Exception(error_message)
+
+
+@retry
+def export_model():
+    try:
+        node_dict = get_node_dict()
+        export_node_id = node_dict['模型导出-2']
+        execute_type = 'EXECUTE_ONE'
+        export_res = PAIClient.create_job(experiment_id, export_node_id, execute_type)
+        export_job_id = export_res['JobId']
+        export_job_detail = wait_job_end(export_job_id)
+        if export_job_detail['Status'] == 'Succeeded':
+            return True
+        return False
+    except Exception as e:
+        error_msg = f"在执行 export_model 函数时发生异常: {str(e)}"
+        raise Exception(error_msg)
 
 
 def update_online_model():
-    node_dict = get_node_dict()
-    train_node_id = node_dict['更新EAS服务(Beta)-1']
-    execute_type = 'EXECUTE_ONE'
-    train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
-    train_job_id = train_res['JobId']
-    train_job_detail = wait_job_end(train_job_id)
-    if train_job_detail['Status'] == 'Succeeded':
-        return True
-    return False
-
-
-def validate_model_data_accuracy(start_time):
-    node_dict = get_node_dict()
-    train_node_id = node_dict['虚拟起始节点']
-    execute_type = 'EXECUTE_FROM_HERE'
-    validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
-    validate_job_id = validate_res['JobId']
-    validate_job_detail = wait_job_end(validate_job_id)
-    if validate_job_detail['Status'] == 'Succeeded':
-        pipeline_run_id = validate_job_detail['RunId']
-        node_id = validate_job_detail['PaiflowNodeId']
-        flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
-        print(flow_out_put_detail)
-        table_dict = {}
-        out_puts = flow_out_put_detail['Outputs']
-        for out_put in out_puts:
-            if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
-                value1 = json.loads(out_put["Info"]['value'])
-                table_dict['二分类评估-1'] = value1['location']['table']
-            if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
-                value2 = json.loads(out_put["Info"]['value'])
-                table_dict['二分类评估-2'] = value2['location']['table']
-            if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
-                value3 = json.loads(out_put["Info"]['value'])
-                table_dict['预测结果对比'] = value3['location']['table']
+    try:
+        node_dict = get_node_dict()
+        train_node_id = node_dict['更新EAS服务(Beta)-1']
+        execute_type = 'EXECUTE_ONE'
+        train_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+        train_job_id = train_res['JobId']
+        train_job_detail = wait_job_end(train_job_id)
+        if train_job_detail['Status'] == 'Succeeded':
+            return True
+        return False
+    except Exception as e:
+        error_message = f"在执行 update_online_model 函数时发生异常: {str(e)}"
+        print(error_message)
+        raise Exception(error_message)
 
+
+@retry
+def get_validate_model_data():
+    try:
+        node_dict = get_node_dict()
+        train_node_id = node_dict['虚拟起始节点']
+        execute_type = 'EXECUTE_FROM_HERE'
+        validate_res = PAIClient.create_job(experiment_id, train_node_id, execute_type)
+        validate_job_id = validate_res['JobId']
+        validate_job_detail = wait_job_end(validate_job_id)
+        if validate_job_detail['Status'] == 'Succeeded':
+            return True
+        return False
+    except Exception as e:
+        error_message = f"在执行 get_validate_model_data 函数时出现异常: {e}"
+        print(error_message)
+        raise Exception(error_message)
+
+
+def validate_model_data_accuracy():
+    try:
+        table_dict = {}
+        node_dict = get_node_dict()
+        job_dict = get_job_dict()
+        job_id = job_dict['虚拟起始节点']
+        validate_job_detail = wait_job_end(job_id)
+        if validate_job_detail['Status'] == 'Succeeded':
+            pipeline_run_id = validate_job_detail['RunId']
+            node_id = validate_job_detail['PaiflowNodeId']
+            flow_out_put_detail = PAIClient.get_flow_out_put(pipeline_run_id, node_id, 3)
+            print(flow_out_put_detail)
+            out_puts = flow_out_put_detail['Outputs']
+            for out_put in out_puts:
+                if out_put["Producer"] == node_dict['二分类评估-1'] and out_put["Name"] == "outputMetricTable":
+                    value1 = json.loads(out_put["Info"]['value'])
+                    table_dict['二分类评估-1'] = value1['location']['table']
+                if out_put["Producer"] == node_dict['二分类评估-2'] and out_put["Name"] == "outputMetricTable":
+                    value2 = json.loads(out_put["Info"]['value'])
+                    table_dict['二分类评估-2'] = value2['location']['table']
+                if out_put["Producer"] == node_dict['预测结果对比'] and out_put["Name"] == "outputTable":
+                    value3 = json.loads(out_put["Info"]['value'])
+                    table_dict['预测结果对比'] = value3['location']['table']
         num = 10
         df = get_data_from_odps('pai_algo', table_dict['预测结果对比'], 10)
         # 对指定列取绝对值再求和
@@ -558,7 +668,7 @@ def validate_model_data_accuracy(start_time):
         bizdate = get_previous_days_date(1)
         score_diff = abs(old_abs_avg - new_abs_avg)
         msg = ""
-        level = ""
+        result = False
         if new_abs_avg > 0.1:
             msg += f'线上模型评估{bizdate}的数据,绝对误差大于0.1,请检查'
             level = 'error'
@@ -566,11 +676,9 @@ def validate_model_data_accuracy(start_time):
             msg += f'两个模型评估${bizdate}的数据,两个模型分数差异为: ${score_diff}, 大于0.05, 请检查'
             level = 'error'
         else:
-            # update_online_model()
             msg += 'DNN广告模型更新完成'
             level = 'info'
-        step_end_time = int(time.time())
-        elapsed = step_end_time - start_time
+            result = True
 
         # 初始化表格头部
         top10_msg = "| CID  | 老模型相对真实CTCVR的变化 | 新模型相对真实CTCVR的变化 |"
@@ -587,17 +695,28 @@ def validate_model_data_accuracy(start_time):
         msg += f"\n\t - 新模型AUC: {new_auc}"
         msg += f"\n\t - 老模型Top10差异平均值: {old_abs_avg}"
         msg += f"\n\t - 新模型Top10差异平均值: {new_abs_avg}"
-        _monitor(level, msg, start_time, elapsed, top10_msg)
+        return result, msg, level, top10_msg
+
+    except Exception as e:
+        error_message = f"在执行 validate_model_data_accuracy 函数时出现异常: {str(e)}"
+        print(error_message)
+        raise Exception(error_message)
 
 
 if __name__ == '__main__':
     start_time = int(time.time())
-    # 1.更新工作流
-    update_online_flow()
-    # 2.训练模型
-    train_res = train_model()
-    if train_res:
-        # 3. 验证模型数据 & 更新模型到线上
-        validate_model_data_accuracy(start_time)
+    functions = [shuffle_table, shuffle_train_model, export_model, get_validate_model_data]
+    for func in functions:
+        if not func():
+            print(f"{func.__name__} 执行失败,后续函数不再执行。")
+            break
     else:
-        print('train_model_error')
+        print("所有函数都成功执行,可以继续下一步操作。")
+        result, msg, level, top10_msg = validate_model_data_accuracy()
+        if result:
+            # update_online_model()
+            print("success")
+        step_end_time = int(time.time())
+        elapsed = step_end_time - start_time
+        print(level, msg, start_time, elapsed, top10_msg)
+        _monitor(level, msg, start_time, elapsed, top10_msg)