丁云鹏 5 月之前
父節點
當前提交
08563909a5

+ 1 - 2
recommend-model-produce/src/main/python/tools/static_trainer.py

@@ -226,8 +226,7 @@ def main(args):
             fetchvars = []
 
             for op in paddle.static.default_main_program().global_block().ops:
-                for name in op.input_arg_names:
-                    logger.info("op.input_arg_names {}".format(name))
+                logger.info("op.input_arg_names {}".format(op.input_arg_names))
             
 
             for var_name in feed_var_names:

+ 13 - 0
recommend-model-produce/src/main/python/tools/utils/save_load.py

@@ -74,11 +74,24 @@ def save_inference_model(model_path,
     model_path = os.path.join(model_path, str(epoch_id))
     _mkdir_if_not_exist(model_path)
     model_prefix = os.path.join(model_path, prefix)
+
+
+    program = paddle.static._get_valid_program(None)
+    program = paddle.static.normalize_program(
+        program,
+        feed_vars,
+        fetch_vars,
+        skip_prune_program=False)
+    )
+    logger.info("global block has follow vars: {}".format(program.global_block().vars.keys()))
+
+
     paddle.static.save_inference_model(
         path_prefix=model_prefix,
         feed_vars=feed_vars,
         fetch_vars=fetch_vars,
         executor=exe)
+    return model_prefix
 
 
 def load_static_model(program, model_path, prefix='rec_static'):