Browse Source

save model

丁云鹏 5 tháng trước
mục cha
commit
88cd9055e2

+ 33 - 3
recommend-model-produce/src/main/python/tools/static_ps_trainer_v2.py

@@ -232,6 +232,7 @@ class Main(object):
 
             model_dir = "{}/{}".format(save_model_path, epoch)
 
+
             if is_distributed_env():
                 fleet.save_inference_model(
                     self.exe, model_dir,
@@ -253,11 +254,40 @@ class Main(object):
                 self.exe,
                 model_filename=None,
                 params_filename=None)
-            
+
+            feed_var_names = [feed.name for feed in self.inference_feed_var]
+            feedvars = []
+            fetch_var_names = feed_var_names + [self.inference_target_var]
+            fetchvars = []
+
+            for var_name in feed_var_names:
+                if var_name not in paddle.static.default_main_program(
+                ).global_block().vars:
+                    raise ValueError(
+                        "Feed variable: {} not in default_main_program, global block has follow vars: {}".
+                        format(var_name,
+                               paddle.static.default_main_program()
+                               .global_block().vars.keys()))
+                else:
+                    feedvars.append(paddle.static.default_main_program()
+                                    .global_block().vars[var_name])
+            for var_name in fetch_var_names:
+                if var_name not in paddle.static.default_main_program(
+                ).global_block().vars:
+                    raise ValueError(
+                        "Fetch variable: {} not in default_main_program, global block has follow vars: {}".
+                        format(var_name,
+                               paddle.static.default_main_program()
+                               .global_block().vars.keys()))
+                else:
+                    fetchvars.append(paddle.static.default_main_program()
+                                     .global_block().vars[var_name])
+
             paddle.static.save_inference_model(
                 os.path.join(model_dir, "dnn_plugin_new"),
-                [feed.name for feed in self.inference_feed_var],
-                [self.inference_target_var], self.exe)
+                feedvars,
+                fetchvars, 
+                self.exe)
 
             compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new"), "dnn_plugin_new.tar.gz")
             client = HangZhouOSSClient("art-recommend")