Forráskód Böngészése

修改模型更新脚本全局变量

xueyiming 5 napja
szülő
commit
908b38f403
1 módosított fájl, 42 hozzáadás és 36 törlés
  1. 42 36
      ad/pai_flow_operator2.py

+ 42 - 36
ad/pai_flow_operator2.py

@@ -387,13 +387,16 @@ def extract_date_yyyymmdd(input_string):
     return None
 
 
-def get_online_version_dt(service_name: str):
+def get_online_model_config(service_name: str):
+    model_config = {}
     model_detail = PAIClient.get_describe_service(service_name)
     service_config_str = model_detail['ServiceConfig']
     service_config = json.loads(service_config_str)
     model_path = service_config['model_path']
+    model_config['model_path'] = model_path
     online_date = extract_date_yyyymmdd(model_path)
-    return online_date
+    model_config['online_date'] = online_date
+    return model_config
 
 
 
@@ -493,7 +496,7 @@ def get_job_dict():
 @retry
 def update_online_flow():
     try:
-        online_version_dt = get_online_version_dt('ad_rank_dnn_v11_easyrec')
+        online_model_config = get_online_model_config('ad_rank_dnn_v11_easyrec')
         draft = PAIClient.get_work_flow_draft(experiment_id)
         print(json.dumps(draft, ensure_ascii=False))
         content = draft['Content']
@@ -508,9 +511,11 @@ def update_online_flow():
                 if global_param['name'] == 'bizdate':
                     global_param['value'] = bizdate
                 if global_param['name'] == 'online_version_dt':
-                    global_param['value'] = online_version_dt
+                    global_param['value'] = online_model_config['online_date']
                 if global_param['name'] == 'eval_date':
                     global_param['value'] = bizdate
+                if global_param['name'] == 'online_model_path':
+                    global_param['value'] = online_model_config['model_path']
             except KeyError:
                 raise Exception("在处理全局参数时,字典中缺少必要的键")
         for node in nodes:
@@ -709,35 +714,36 @@ def validate_model_data_accuracy():
 
 
 if __name__ == '__main__':
-    start_time = int(time.time())
-    functions = [update_online_flow, shuffle_table, shuffle_train_model, export_model, get_validate_model_data]
-    function_names = [func.__name__ for func in functions]
-
-    start_function = None
-    if len(sys.argv) > 1:
-        start_function = sys.argv[1]
-        if start_function not in function_names:
-            print(f"指定的起始函数 {start_function} 不存在,请选择以下函数之一:{', '.join(function_names)}")
-            sys.exit(1)
-
-    start_index = 0
-    if start_function:
-        start_index = function_names.index(start_function)
-
-    for func in functions[start_index:]:
-        if not func():
-            print(f"{func.__name__} 执行失败,后续函数不再执行。")
-            step_end_time = int(time.time())
-            elapsed = step_end_time - start_time
-            _monitor('error', f"DNN模型更新,{func.__name__} 执行失败,后续函数不再执行,请检查", start_time, elapsed, None)
-            break
-    else:
-        print("所有函数都成功执行,可以继续下一步操作。")
-        result, msg, level, top10_msg = validate_model_data_accuracy()
-        if result:
-            # update_online_model()
-            print("success")
-        step_end_time = int(time.time())
-        elapsed = step_end_time - start_time
-        print(level, msg, start_time, elapsed, top10_msg)
-        _monitor(level, msg, start_time, elapsed, top10_msg)
+    update_online_flow()
+    # start_time = int(time.time())
+    # functions = [update_online_flow, shuffle_table, shuffle_train_model, export_model, get_validate_model_data]
+    # function_names = [func.__name__ for func in functions]
+    #
+    # start_function = None
+    # if len(sys.argv) > 1:
+    #     start_function = sys.argv[1]
+    #     if start_function not in function_names:
+    #         print(f"指定的起始函数 {start_function} 不存在,请选择以下函数之一:{', '.join(function_names)}")
+    #         sys.exit(1)
+    #
+    # start_index = 0
+    # if start_function:
+    #     start_index = function_names.index(start_function)
+    #
+    # for func in functions[start_index:]:
+    #     if not func():
+    #         print(f"{func.__name__} 执行失败,后续函数不再执行。")
+    #         step_end_time = int(time.time())
+    #         elapsed = step_end_time - start_time
+    #         _monitor('error', f"DNN模型更新,{func.__name__} 执行失败,后续函数不再执行,请检查", start_time, elapsed, None)
+    #         break
+    # else:
+    #     print("所有函数都成功执行,可以继续下一步操作。")
+    #     result, msg, level, top10_msg = validate_model_data_accuracy()
+    #     if result:
+    #         # update_online_model()
+    #         print("success")
+    #     step_end_time = int(time.time())
+    #     elapsed = step_end_time - start_time
+    #     print(level, msg, start_time, elapsed, top10_msg)
+    #     _monitor(level, msg, start_time, elapsed, top10_msg)