|
@@ -74,11 +74,24 @@ def save_inference_model(model_path,
|
|
|
model_path = os.path.join(model_path, str(epoch_id))
|
|
|
_mkdir_if_not_exist(model_path)
|
|
|
model_prefix = os.path.join(model_path, prefix)
|
|
|
+
|
|
|
+
|
|
|
+ program = paddle.static._get_valid_program(None)
|
|
|
+ program = paddle.static.normalize_program(
|
|
|
+ program,
|
|
|
+ feed_vars,
|
|
|
+ fetch_vars,
|
|
|
+ skip_prune_program=False)
|
|
|
+ )
|
|
|
+ logger.info("global block has follow vars: {}".format(program.global_block().vars.keys()))
|
|
|
+
|
|
|
+
|
|
|
paddle.static.save_inference_model(
|
|
|
path_prefix=model_prefix,
|
|
|
feed_vars=feed_vars,
|
|
|
fetch_vars=fetch_vars,
|
|
|
executor=exe)
|
|
|
+ return model_prefix
|
|
|
|
|
|
|
|
|
def load_static_model(program, model_path, prefix='rec_static'):
|