|
@@ -21,7 +21,7 @@ from paddle.static.io import _get_valid_program, normalize_program
|
|
|
# Module-wide logging setup: timestamped INFO-level messages.
# NOTE(review): this span was diff/gutter residue ("|", "-", "+" markers);
# reconstructed to the post-patch state of the hunk.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

logger = logging.getLogger(__name__)
# Redundant with basicConfig(level=INFO) above, but kept deliberately: it pins
# this module's logger to INFO even if the root logger's level is changed later.
logger.setLevel(logging.INFO)
|
|
|
|
|
|
def save_model(net, optimizer, model_path, epoch_id, prefix='rec'):
|
|
|
model_path = os.path.join(model_path, str(epoch_id))
|
|
@@ -78,12 +78,45 @@ def save_inference_model(model_path,
|
|
|
|
|
|
|
|
|
program = _get_valid_program(None)
|
|
|
- program = normalize_program(
|
|
|
- program,
|
|
|
- feed_vars,
|
|
|
- fetch_vars,
|
|
|
- skip_prune_program=True)
|
|
|
- logger.info("global block has follow vars: {}".format(program.global_block().vars.keys()))
|
|
|
+ copy_program = program.clone()
|
|
|
+ logger.info("program.clone(): {}".format(program.global_block().vars.keys()))
|
|
|
+
|
|
|
+ global_block = copy_program.global_block()
|
|
|
+ remove_op_idx = []
|
|
|
+ for i, op in enumerate(global_block.ops):
|
|
|
+ op.desc.set_is_target(False)
|
|
|
+ if op.type == "feed" or op.type == "fetch":
|
|
|
+ remove_op_idx.append(i)
|
|
|
+
|
|
|
+ if op.type == "pylayer":
|
|
|
+ sub_blocks_ids = op._blocks_attr_ids("blocks")
|
|
|
+ if len(sub_blocks_ids) > 1:
|
|
|
+
|
|
|
+ backward_block_id = sub_blocks_ids[-1]
|
|
|
+
|
|
|
+ copy_program.blocks.pop(backward_block_id)
|
|
|
+
|
|
|
+ reserverd_blocks = []
|
|
|
+ for block_id in sub_blocks_ids[:-1]:
|
|
|
+ reserverd_blocks.append(copy_program.block(block_id))
|
|
|
+ op._update_desc_attr("blocks", reserverd_blocks)
|
|
|
+
|
|
|
+ for idx in remove_op_idx[::-1]:
|
|
|
+ global_block._remove_op(idx)
|
|
|
+ copy_program.desc.flush()
|
|
|
+ logger.info("copy_program.desc.flush(): {}".format(program.global_block().vars.keys()))
|
|
|
+
|
|
|
+ feed_var_names = [var.name for var in feed_vars]
|
|
|
+
|
|
|
+ skip_prune_program = False
|
|
|
+ if not skip_prune_program:
|
|
|
+ copy_program = copy_program._prune_with_input(
|
|
|
+ feeded_var_names=feed_var_names, targets=fetch_vars
|
|
|
+ )
|
|
|
+ logger.info("copy_program._prune_with_input(): {}".format(program.global_block().vars.keys()))
|
|
|
+ copy_program = copy_program._inference_optimize(prune_read_op=True)
|
|
|
+ logger.info("copy_program._inference_optimize(): {}".format(program.global_block().vars.keys()))
|
|
|
+
|
|
|
|
|
|
|
|
|
paddle.static.save_inference_model(
|