|
@@ -249,45 +249,38 @@ class Main(object):
|
|
|
|
|
|
# trans to new format
|
|
|
# {"model_filename":"", "params_filename":""} fleet每个参数一个文件,需要同这种方式加载
|
|
|
- paddle.static.load_inference_model(
|
|
|
+ program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
|
|
|
os.path.join(model_dir, "dnn_plugin"),
|
|
|
self.exe,
|
|
|
model_filename=None,
|
|
|
params_filename=None)
|
|
|
|
|
|
- feed_var_names = [feed.name for feed in self.inference_feed_var]
|
|
|
feedvars = []
|
|
|
- fetch_var_names = feed_var_names + [self.inference_target_var]
|
|
|
fetchvars = []
|
|
|
|
|
|
- for var_name in feed_var_names:
|
|
|
- if var_name not in paddle.static.default_main_program(
|
|
|
- ).global_block().vars:
|
|
|
+ for var_name in feed_target_names:
|
|
|
+ if var_name not in program.global_block().vars:
|
|
|
raise ValueError(
|
|
|
"Feed variable: {} not in default_main_program, global block has follow vars: {}".
|
|
|
format(var_name,
|
|
|
- paddle.static.default_main_program()
|
|
|
- .global_block().vars.keys()))
|
|
|
+ program.global_block().vars.keys()))
|
|
|
else:
|
|
|
- feedvars.append(paddle.static.default_main_program()
|
|
|
- .global_block().vars[var_name])
|
|
|
- for var_name in fetch_var_names:
|
|
|
- if var_name not in paddle.static.default_main_program(
|
|
|
- ).global_block().vars:
|
|
|
+ feedvars.append(program.global_block().vars[var_name])
|
|
|
+ for var_name in fetch_targets:
|
|
|
+ if var_name not in program.global_block().vars:
|
|
|
raise ValueError(
|
|
|
"Fetch variable: {} not in default_main_program, global block has follow vars: {}".
|
|
|
format(var_name,
|
|
|
- paddle.static.default_main_program()
|
|
|
- .global_block().vars.keys()))
|
|
|
+ program.global_block().vars.keys()))
|
|
|
else:
|
|
|
- fetchvars.append(paddle.static.default_main_program()
|
|
|
- .global_block().vars[var_name])
|
|
|
+ fetchvars.append(program.global_block().vars[var_name])
|
|
|
|
|
|
paddle.static.save_inference_model(
|
|
|
os.path.join(model_dir, "dnn_plugin_new"),
|
|
|
feedvars,
|
|
|
fetchvars,
|
|
|
- self.exe)
|
|
|
+ self.exe,
|
|
|
+ program = program)
|
|
|
|
|
|
compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new"), "dnn_plugin_new.tar.gz")
|
|
|
client = HangZhouOSSClient("art-recommend")
|