often 5 months ago
parent
commit
b5ea3e92c4
1 changed files with 22 additions and 1 deletions
  1. 22 1
      recommend-model-produce/src/main/python/tools/static_ps_trainer.py

+ 22 - 1
recommend-model-produce/src/main/python/tools/static_ps_trainer.py

@@ -220,7 +220,28 @@ class Main(object):
             self.train_result_dict["speed"].append(epoch_speed)
 
             model_dir = "{}/{}".format(save_model_path, epoch)
+            if self.role.is_first_worker():
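+                # 1. Make sure all workers are synchronized (NOTE(review): this barrier is inside the is_first_worker() branch, so only the first worker reaches it — confirm it cannot hang)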
+                fleet.barrier_worker()
+                
+                # 2. Get the main program
+                main_program = fleet.main_program
+                
+                # 3. Use paddle.static.save_inference_model instead of fleet.save_inference_model
+                paddle.static.save_inference_model(
+                    model_dir,
+                    [feed.name for feed in self.inference_feed_var],
+                    self.inference_target_var,
+                    self.exe,
+                    program=main_program,  # use fleet's main program
+                    export_for_deployment=True  # save in the new format
+                )
+                
+                # 4. Synchronize again to make sure the save has finished
+                fleet.barrier_worker()
+
 
+            """
             if is_distributed_env():
                 fleet.save_inference_model(
                     self.exe, model_dir,
@@ -232,7 +253,7 @@ class Main(object):
                     model_dir,
                     [feed.name for feed in self.inference_feed_var],
                     [self.inference_target_var], self.exe)
-
+            """
         if reader_type == "InmemoryDataset":
             self.reader.release_memory()