丁云鹏 6 ヶ月 前
コミット
9cece01667

+ 33 - 33
recommend-model-produce/src/main/python/tools/static_ps_trainer_v2.py

@@ -246,40 +246,40 @@ class Main(object):
                     [self.inference_target_var], self.exe)
 
 
+            if fleet.is_first_worker():
+                # trans to new format
+                # {"model_filename":"", "params_filename":""} fleet每个参数一个文件,需要同这种方式加载
+                program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
+                    os.path.join(model_dir, "dnn_plugin"),
+                    self.exe,
+                    model_filename=None,
+                    params_filename=None)
+
+                feedvars = []
+                fetchvars = fetch_targets
+
+                for var_name in feed_target_names:
+                    if var_name not in program.global_block().vars:
+                        raise ValueError(
+                            "Feed variable: {} not in default_main_program, global block has follow vars: {}".
+                            format(var_name,
+                                   program.global_block().vars.keys()))
+                    else:
+                        feedvars.append(program.global_block().vars[var_name])
 
-            # trans to new format
-            # {"model_filename":"", "params_filename":""} fleet每个参数一个文件,需要同这种方式加载
-            program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
-                os.path.join(model_dir, "dnn_plugin"),
-                self.exe,
-                model_filename=None,
-                params_filename=None)
-
-            feedvars = []
-            fetchvars = fetch_targets
-
-            for var_name in feed_target_names:
-                if var_name not in program.global_block().vars:
-                    raise ValueError(
-                        "Feed variable: {} not in default_main_program, global block has follow vars: {}".
-                        format(var_name,
-                               program.global_block().vars.keys()))
-                else:
-                    feedvars.append(program.global_block().vars[var_name])
-
-            paddle.static.save_inference_model(
-                os.path.join(model_dir, "dnn_plugin_new"),
-                feedvars,
-                fetchvars, 
-                self.exe,
-                program = program)
-
-            compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new"), "dnn_plugin_new.tar.gz")
-            client = HangZhouOSSClient("art-recommend")
-            client.put_object_from_file("dyp/model.tar.gz", "dnn_plugin_new.tar.gz")
-            while True:
-                time.sleep(300)
-                continue;
+                paddle.static.save_inference_model(
+                    os.path.join(model_dir, "dnn_plugin_new"),
+                    feedvars,
+                    fetchvars, 
+                    self.exe,
+                    program = program)
+
+                compress.compress_tar(os.path.join(model_dir, "dnn_plugin_new/dssm"), "dssm.tar.gz")
+                client = HangZhouOSSClient("art-recommend")
+                client.put_object_from_file("dyp/dssm.tar.gz", "dssm.tar.gz")
+                while True:
+                    time.sleep(300)
+                    continue;
 
         if reader_type == "InmemoryDataset":
             self.reader.release_memory()