|
@@ -220,7 +220,7 @@ class Main(object):
|
|
self.train_result_dict["speed"].append(epoch_speed)
|
|
self.train_result_dict["speed"].append(epoch_speed)
|
|
|
|
|
|
model_dir = "{}/{}".format(save_model_path, epoch)
|
|
model_dir = "{}/{}".format(save_model_path, epoch)
|
|
- if self.role.is_first_worker():
|
|
|
|
|
|
+ if paddle.distributed.get_rank() == 0:
|
|
# 1. 确保所有 worker 同步
|
|
# 1. 确保所有 worker 同步
|
|
fleet.barrier_worker()
|
|
fleet.barrier_worker()
|
|
|
|
|