|
@@ -16,7 +16,7 @@ import paddle
|
|
|
import os
|
|
|
import logging
|
|
|
import numpy as np
|
|
|
-from paddle.static.io import _get_valid_program, normalize_program
|
|
|
+from paddle.static.io import _get_valid_program, normalize_program, program_guard
|
|
|
|
|
|
logging.basicConfig(
|
|
|
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
|
|
@@ -75,9 +75,31 @@ def save_inference_model(model_path,
|
|
|
model_path = os.path.join(model_path, str(epoch_id))
|
|
|
_mkdir_if_not_exist(model_path)
|
|
|
model_prefix = os.path.join(model_path, prefix)
|
|
|
-
|
|
|
+ fetch_vars2 = fetch_vars
|
|
|
|
|
|
program = _get_valid_program(None)
|
|
|
+
|
|
|
+ for op in program.global_block().ops:
|
|
|
+ # clear device of Op
|
|
|
+ device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
|
|
|
+ op._set_attr(device_attr_name, "")
|
|
|
+ if op.type == 'auc':
|
|
|
+ warnings.warn(
|
|
|
+ "Be sure that you have set auc states to 0 before saving inference model."
|
|
|
+ )
|
|
|
+ break
|
|
|
+
|
|
|
+ # fix the bug that the activation op's output as target will be pruned.
|
|
|
+ # will affect the inference performance.
|
|
|
+ # TODO(Superjomn) add an IR pass to remove 1-scale op.
|
|
|
+ with program_guard(program):
|
|
|
+ uniq_fetch_vars = []
|
|
|
+ for i, var in enumerate(fetch_vars):
|
|
|
+ if var.dtype != paddle.bool:
|
|
|
+ var = paddle.scale(var, 1.0, name=f"save_infer_model/scale_{i}")
|
|
|
+ uniq_fetch_vars.append(var)
|
|
|
+ fetch_vars = uniq_fetch_vars
|
|
|
+
|
|
|
copy_program = program.clone()
|
|
|
logger.info("program.clone(): {}".format(copy_program.global_block().vars.keys()))
|
|
|
|
|
@@ -122,7 +144,7 @@ def save_inference_model(model_path,
|
|
|
paddle.static.save_inference_model(
|
|
|
path_prefix=model_prefix,
|
|
|
feed_vars=feed_vars,
|
|
|
- fetch_vars=fetch_vars,
|
|
|
+ fetch_vars=fetch_vars2,
|
|
|
executor=exe)
|
|
|
return model_prefix
|
|
|
|