|
@@ -232,6 +232,7 @@ class Main(object):
|
|
|
|
|
|
model_dir = "{}/{}".format(save_model_path, epoch)
|
|
|
|
|
|
+
|
|
|
if is_distributed_env():
|
|
|
fleet.save_inference_model(
|
|
|
self.exe, model_dir,
|
|
@@ -253,11 +254,40 @@ class Main(object):
|
|
|
self.exe,
|
|
|
model_filename=None,
|
|
|
params_filename=None)
|
|
|
-
|
|
|
+
|
|
|
+ feed_var_names = [feed.name for feed in self.inference_feed_var]
|
|
|
+ feedvars = []
|
|
|
+ fetch_var_names = feed_var_names + [self.inference_target_var]
|
|
|
+ fetchvars = []
|
|
|
+
|
|
|
+ for var_name in feed_var_names:
|
|
|
+ if var_name not in paddle.static.default_main_program(
|
|
|
+ ).global_block().vars:
|
|
|
+ raise ValueError(
|
|
|
+ "Feed variable: {} not in default_main_program, global block has follow vars: {}".
|
|
|
+ format(var_name,
|
|
|
+ paddle.static.default_main_program()
|
|
|
+ .global_block().vars.keys()))
|
|
|
+ else:
|
|
|
+ feedvars.append(paddle.static.default_main_program()
|
|
|
+ .global_block().vars[var_name])
|
|
|
+ for var_name in fetch_var_names:
|
|
|
+ if var_name not in paddle.static.default_main_program(
|
|
|
+ ).global_block().vars:
|
|
|
+ raise ValueError(
|
|
|
+ "Fetch variable: {} not in default_main_program, global block has follow vars: {}".
|
|
|
+ format(var_name,
|
|
|
+ paddle.static.default_main_program()
|
|
|
+ .global_block().vars.keys()))
|
|
|
+ else:
|
|
|
+ fetchvars.append(paddle.static.default_main_program()
|
|
|
+ .global_block().vars[var_name])
|
|
|
+
|
|
|
paddle.static.save_inference_model(
|
|
|
os.path.join(model_dir, "dnn_plugin_new"),
|
|
|
- [feed.name for feed in self.inference_feed_var],
|
|
|
- [self.inference_target_var], self.exe)
|
|
|
+ feedvars,
|
|
|
+ fetchvars,
|
|
|
+ self.exe)
|
|
|
|
|
|
compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new"), "dnn_plugin_new.tar.gz")
|
|
|
client = HangZhouOSSClient("art-recommend")
|