|
@@ -220,7 +220,28 @@ class Main(object):
|
|
|
self.train_result_dict["speed"].append(epoch_speed)
|
|
|
|
|
|
model_dir = "{}/{}".format(save_model_path, epoch)
|
|
|
+ if self.role.is_first_worker():
|
|
|
+ # 1. Sync all workers before saving -- NOTE(review): barrier_worker() sits inside the
+ #    is_first_worker() branch, so non-first workers never reach the barrier and the job
+ #    will hang; the barrier must be moved outside the if. Confirm before merging.
|
|
|
+ fleet.barrier_worker()
|
|
|
+
|
|
|
+ # 2. Get the fleet main program
|
|
|
+ main_program = fleet.main_program
|
|
|
+
|
|
|
+ # 3. Use paddle.static.save_inference_model instead of fleet.save_inference_model
|
|
|
+ paddle.static.save_inference_model(
|
|
|
+ model_dir,
|
|
|
+ [feed.name for feed in self.inference_feed_var],
|
|
|
+ self.inference_target_var,
|
|
|
+ self.exe,
|
|
|
+ program=main_program, # use fleet's main program
|
|
|
+ export_for_deployment=True # NOTE(review): paddle.static.save_inference_model has no export_for_deployment kwarg (legacy fluid.io API only) -- verify against the installed Paddle version
|
|
|
+ )
|
|
|
+
|
|
|
+ # 4. Sync again to ensure the save completed -- NOTE(review): same deadlock concern as
+ #    above; only the first worker reaches this barrier.
|
|
|
+ fleet.barrier_worker()
|
|
|
+
|
|
|
|
|
|
+ """
|
|
|
if is_distributed_env():
|
|
|
fleet.save_inference_model(
|
|
|
self.exe, model_dir,
|
|
@@ -232,7 +253,7 @@ class Main(object):
|
|
|
model_dir,
|
|
|
[feed.name for feed in self.inference_feed_var],
|
|
|
[self.inference_target_var], self.exe)
|
|
|
-
|
|
|
+ """
|
|
|
if reader_type == "InmemoryDataset":
|
|
|
self.reader.release_memory()
|
|
|
|