often 5 months ago
parent
commit
0f3110a64b
1 changed files with 1 additions and 22 deletions
  1. 1 22
      recommend-model-produce/src/main/python/tools/static_ps_trainer.py

+ 1 - 22
recommend-model-produce/src/main/python/tools/static_ps_trainer.py

@@ -220,28 +220,7 @@ class Main(object):
             self.train_result_dict["speed"].append(epoch_speed)
 
             model_dir = "{}/{}".format(save_model_path, epoch)
-            if paddle.distributed.get_rank() == 0:
-                # 1. 确保所有 worker 同步
-                fleet.barrier_worker()
-                
-                # 2. 获取主程序
-                main_program = paddle.static.default_main_program()
-                
-                # 3. 使用 paddle.static.save_inference_model 替代 fleet.save_inference_model
-                paddle.static.save_inference_model(
-                    model_dir,
-                    [feed.name for feed in self.inference_feed_var],
-                    self.inference_target_var,
-                    self.exe,
-                    program=main_program,  # 使用 fleet 的主程序
-                    export_for_deployment=True  # 保存为新格式
-                )
-                
-                # 4. 再次同步确保保存完成
-                fleet.barrier_worker()
-
 
-            """
             if is_distributed_env():
                 fleet.save_inference_model(
                     self.exe, model_dir,
@@ -253,7 +232,7 @@ class Main(object):
                     model_dir,
                     [feed.name for feed in self.inference_feed_var],
                     [self.inference_target_var], self.exe)
-            """
+
         if reader_type == "InmemoryDataset":
             self.reader.release_memory()