@@ -76,66 +76,11 @@ def save_inference_model(model_path,
     model_path = os.path.join(model_path, str(epoch_id))
     _mkdir_if_not_exist(model_path)
     model_prefix = os.path.join(model_path, prefix)
 
-    fetch_vars2 = fetch_vars
-
-    program = _get_valid_program(None)
-
-    # fix the bug that the activation op's output as target will be pruned.
-    # will affect the inference performance.
-    # TODO(Superjomn) add an IR pass to remove 1-scale op.
-    with program_guard(program):
-        uniq_fetch_vars = []
-        for i, var in enumerate(fetch_vars):
-            if var.dtype != paddle.bool:
-                var = paddle.scale(var, 1.0, name=f"save_infer_model/scale_{i}")
-            uniq_fetch_vars.append(var)
-        fetch_vars = uniq_fetch_vars
-
-    copy_program = program.clone()
-    logger.info("program.clone(): {}".format(copy_program.global_block().vars.keys()))
-
-    global_block = copy_program.global_block()
-    remove_op_idx = []
-    for i, op in enumerate(global_block.ops):
-        op.desc.set_is_target(False)
-        if op.type == "feed" or op.type == "fetch":
-            remove_op_idx.append(i)
-
-        if op.type == "pylayer":
-            sub_blocks_ids = op._blocks_attr_ids("blocks")
-            if len(sub_blocks_ids) > 1:
-                # pylayer op ``blocks`` attr contains forward block id and backward block id
-                backward_block_id = sub_blocks_ids[-1]
-                # remove backward block
-                copy_program.blocks.pop(backward_block_id)
-                # update attrs ``blocks``
-                reserverd_blocks = []
-                for block_id in sub_blocks_ids[:-1]:
-                    reserverd_blocks.append(copy_program.block(block_id))
-                op._update_desc_attr("blocks", reserverd_blocks)
-
-    for idx in remove_op_idx[::-1]:
-        global_block._remove_op(idx)
-    copy_program.desc.flush()
-    logger.info("copy_program.desc.flush(): {}".format(copy_program.global_block().vars.keys()))
-
-    feed_var_names = [var.name for var in feed_vars]
-
-    skip_prune_program = False
-    if not skip_prune_program:
-        copy_program = copy_program._prune_with_input(
-            feeded_var_names=feed_var_names, targets=fetch_vars
-        )
-        logger.info("copy_program._prune_with_input(): {}".format(copy_program.global_block().vars.keys()))
-    copy_program = copy_program._inference_optimize(prune_read_op=True)
-    logger.info("copy_program._inference_optimize(): {}".format(copy_program.global_block().vars.keys()))
-
-
     paddle.static.save_inference_model(
         path_prefix=model_prefix,
         feed_vars=feed_vars,
-        fetch_vars=fetch_vars2,
+        fetch_vars=fetch_vars,
         executor=exe)
     return model_prefix
 
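For context, a minimal sketch of how the patched helper could be driven after this change, which drops the hand-rolled clone/prune logic and lets `paddle.static.save_inference_model` handle pruning internally. The toy network, the import comment, and the keyword-argument order are assumptions for illustration: only `model_path`, `prefix`, `epoch_id`, `feed_vars`, `fetch_vars`, and `exe` are visible in the hunk, and the full signature is truncated in the hunk header.

```python
# Illustrative sketch only: assumes Paddle 2.x static mode and that the
# helper saves from the default main program, as
# paddle.static.save_inference_model does when no program is passed.
import paddle

paddle.enable_static()

# Toy network built in the default main program (made up for this example).
x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
y = paddle.static.nn.fc(x, size=2)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())

# Hypothetical import path; use wherever this helper lives in the repo.
# from export_utils import save_inference_model

model_prefix = save_inference_model(
    model_path="./inference_out",  # files land under ./inference_out/<epoch_id>/
    feed_vars=[x],
    fetch_vars=[y],
    exe=exe,
    epoch_id=0,
    prefix="inference",
)
# Expected: model_prefix == "./inference_out/0/inference"
```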