Browse Source

save model

丁云鹏 5 months ago
parent
commit
7b9e93d2dc
1 changed files with 11 additions and 18 deletions
  1. 11 18
      recommend-model-produce/src/main/python/tools/static_ps_trainer_v2.py

+ 11 - 18
recommend-model-produce/src/main/python/tools/static_ps_trainer_v2.py

@@ -249,45 +249,38 @@ class Main(object):
 
             # trans to new format
             # {"model_filename":"", "params_filename":""} fleet每个参数一个文件,需要同这种方式加载
-            paddle.static.load_inference_model(
+            program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
                 os.path.join(model_dir, "dnn_plugin"),
                 self.exe,
                 model_filename=None,
                 params_filename=None)
 
-            feed_var_names = [feed.name for feed in self.inference_feed_var]
             feedvars = []
-            fetch_var_names = feed_var_names + [self.inference_target_var]
             fetchvars = []
 
-            for var_name in feed_var_names:
-                if var_name not in paddle.static.default_main_program(
-                ).global_block().vars:
+            for var_name in feed_target_names:
+                if var_name not in program.global_block().vars:
                     raise ValueError(
                         "Feed variable: {} not in default_main_program, global block has follow vars: {}".
                         format(var_name,
-                               paddle.static.default_main_program()
-                               .global_block().vars.keys()))
+                               program.global_block().vars.keys()))
                 else:
-                    feedvars.append(paddle.static.default_main_program()
-                                    .global_block().vars[var_name])
-            for var_name in fetch_var_names:
-                if var_name not in paddle.static.default_main_program(
-                ).global_block().vars:
+                    feedvars.append(program.global_block().vars[var_name])
+            for var_name in fetch_targets:
+                if var_name not in program.global_block().vars:
                     raise ValueError(
                         "Fetch variable: {} not in default_main_program, global block has follow vars: {}".
                         format(var_name,
-                               paddle.static.default_main_program()
-                               .global_block().vars.keys()))
+                               program.global_block().vars.keys()))
                 else:
-                    fetchvars.append(paddle.static.default_main_program()
-                                     .global_block().vars[var_name])
+                    fetchvars.append(program.global_block().vars[var_name])
 
             paddle.static.save_inference_model(
                 os.path.join(model_dir, "dnn_plugin_new"),
                 feedvars,
                 fetchvars, 
-                self.exe)
+                self.exe,
+                program = program)
 
             compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new"), "dnn_plugin_new.tar.gz")
             client = HangZhouOSSClient("art-recommend")