|
@@ -38,10 +38,7 @@ import utils.compress as compress
|
|
|
|
|
|
|
|
|
sys.path.append(os.path.abspath("") + os.sep + "lib")
|
|
|
-print(os.path.abspath("") + os.sep + "lib1")
|
|
|
import brpc_flags
|
|
|
-print(os.path.abspath("") + os.sep + "lib2")
|
|
|
-
|
|
|
|
|
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
|
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
|
@@ -245,41 +242,41 @@ class Main(object):
|
|
|
[feed.name for feed in self.inference_feed_var],
|
|
|
[self.inference_target_var], self.exe)
|
|
|
|
|
|
-
|
|
|
- if fleet.is_first_worker():
|
|
|
- # trans to new format
|
|
|
- # {"model_filename":"", "params_filename":""} fleet每个参数一个文件,需要同这种方式加载
|
|
|
- program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
|
|
|
- os.path.join(model_dir, "dnn_plugin"),
|
|
|
- self.exe,
|
|
|
- model_filename=None,
|
|
|
- params_filename=None)
|
|
|
-
|
|
|
- feedvars = []
|
|
|
- fetchvars = fetch_targets
|
|
|
-
|
|
|
- for var_name in feed_target_names:
|
|
|
- if var_name not in program.global_block().vars:
|
|
|
- raise ValueError(
|
|
|
- "Feed variable: {} not in default_main_program, global block has follow vars: {}".
|
|
|
- format(var_name,
|
|
|
- program.global_block().vars.keys()))
|
|
|
- else:
|
|
|
- feedvars.append(program.global_block().vars[var_name])
|
|
|
-
|
|
|
- paddle.static.save_inference_model(
|
|
|
- os.path.join(model_dir, "dnn_plugin_new/dssm"),
|
|
|
- feedvars,
|
|
|
- fetchvars,
|
|
|
- self.exe,
|
|
|
- program = program)
|
|
|
-
|
|
|
- compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new"), "dnn_plugin_new.tar.gz")
|
|
|
- client = HangZhouOSSClient("art-recommend")
|
|
|
- client.put_object_from_file("dyp/dssm.tar.gz", "dnn_plugin_new.tar.gz")
|
|
|
- while True:
|
|
|
- time.sleep(300)
|
|
|
- continue;
|
|
|
+ if fleet.is_first_worker():
|
|
|
+ model_dir = "{}/{}".format(save_model_path, epochs)
|
|
|
+ oss_object_name = self.config.get("runner.oss_object_name")
|
|
|
+ # trans to new format
|
|
|
+ # {"model_filename":"", "params_filename":""} fleet每个参数一个文件,需要同这种方式加载
|
|
|
+ program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
|
|
|
+ os.path.join(model_dir, "dnn_plugin"),
|
|
|
+ self.exe,
|
|
|
+ model_filename=None,
|
|
|
+ params_filename=None)
|
|
|
+
|
|
|
+ feedvars = []
|
|
|
+ fetchvars = fetch_targets
|
|
|
+
|
|
|
+ for var_name in feed_target_names:
|
|
|
+ if var_name not in program.global_block().vars:
|
|
|
+ raise ValueError(
|
|
|
+ "Feed variable: {} not in default_main_program, global block has follow vars: {}".
|
|
|
+ format(var_name,
|
|
|
+ program.global_block().vars.keys()))
|
|
|
+ else:
|
|
|
+ feedvars.append(program.global_block().vars[var_name])
|
|
|
+ paddle.static.save_inference_model(
|
|
|
+ os.path.join(model_dir, "dnn_plugin_new/dssm"),
|
|
|
+ feedvars,
|
|
|
+ fetchvars,
|
|
|
+ self.exe,
|
|
|
+ program = program)
|
|
|
+
|
|
|
+ compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new"), "dnn_plugin_new.tar.gz")
|
|
|
+ client = HangZhouOSSClient("art-recommend")
|
|
|
+ client.put_object_from_file(oss_object_name, "dnn_plugin_new.tar.gz")
|
|
|
+ while True:
|
|
|
+ time.sleep(300)
|
|
|
+ continue;
|
|
|
|
|
|
if reader_type == "InmemoryDataset":
|
|
|
self.reader.release_memory()
|