丁云鹏 5 ヶ月 前
コミット
44bcd602d2

+ 3 - 14
recommend-model-produce/src/main/python/tools/static_ps_trainer_v2.py

@@ -253,25 +253,14 @@ class Main(object):
                 model_filename=None,
                 params_filename=None)
 
-            feedvars = []
-            fetchvars = fetch_targets
-
-            for var_name in feed_target_names:
-                if var_name not in program.global_block().vars:
-                    raise ValueError(
-                        "Feed variable: {} not in default_main_program, global block has follow vars: {}".
-                        format(var_name,
-                               program.global_block().vars.keys()))
-                else:
-                    feedvars.append(program.global_block().vars[var_name])
             paddle.static.save_inference_model(
                 os.path.join(model_dir, "dnn_plugin_new/dssm"),
-                feedvars,
-                fetchvars, 
+                self.inference_feed_var,
+                [self.inference_target_var], 
                 self.exe,
                 program = program)
 
-            logger.info("program.global_block().vars.keys() {}".format(program.global_block().vars.keys()))
+            # logger.info("program.global_block().vars.keys() {}".format(program.global_block().vars.keys()))
 
             compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new"), "dnn_plugin_new.tar.gz")
             client = HangZhouOSSClient("art-recommend")