|
@@ -220,28 +220,7 @@ class Main(object):
|
|
|
self.train_result_dict["speed"].append(epoch_speed)
|
|
|
|
|
|
model_dir = "{}/{}".format(save_model_path, epoch)
|
|
|
- if paddle.distributed.get_rank() == 0:
|
|
|
- # 1. 确保所有 worker 同步
|
|
|
- fleet.barrier_worker()
|
|
|
-
|
|
|
- # 2. 获取主程序
|
|
|
- main_program = paddle.static.default_main_program()
|
|
|
-
|
|
|
- # 3. 使用 paddle.static.save_inference_model 替代 fleet.save_inference_model
|
|
|
- paddle.static.save_inference_model(
|
|
|
- model_dir,
|
|
|
- [feed.name for feed in self.inference_feed_var],
|
|
|
- self.inference_target_var,
|
|
|
- self.exe,
|
|
|
- program=main_program, # 使用 fleet 的主程序
|
|
|
- export_for_deployment=True # 保存为新格式
|
|
|
- )
|
|
|
-
|
|
|
- # 4. 再次同步确保保存完成
|
|
|
- fleet.barrier_worker()
|
|
|
-
|
|
|
|
|
|
- """
|
|
|
if is_distributed_env():
|
|
|
fleet.save_inference_model(
|
|
|
self.exe, model_dir,
|
|
@@ -253,7 +232,7 @@ class Main(object):
|
|
|
model_dir,
|
|
|
[feed.name for feed in self.inference_feed_var],
|
|
|
[self.inference_target_var], self.exe)
|
|
|
- """
|
|
|
+
|
|
|
if reader_type == "InmemoryDataset":
|
|
|
self.reader.release_memory()
|
|
|
|