@@ -220,7 +220,7 @@ class Main(object):
self.train_result_dict["speed"].append(epoch_speed)
model_dir = "{}/{}".format(save_model_path, epoch)
- if self.role.is_first_worker():
+ if paddle.distributed.get_rank() == 0:
# 1. 确保所有 worker 同步
fleet.barrier_worker()