|
@@ -122,7 +122,7 @@ class Main(object):
|
|
|
is_collective=False,
|
|
|
init_gloo=False
|
|
|
)
|
|
|
- fleet.init(role,config=fleet_config)
|
|
|
+ fleet.init(role)
|
|
|
#logger.info("worker_index: %s", fleet.worker_index())
|
|
|
#logger.info("is_first_worker: %s", fleet.is_first_worker())
|
|
|
#logger.info("worker_num: %s", fleet.worker_num())
|
|
@@ -131,9 +131,8 @@ class Main(object):
|
|
|
|
|
|
else:
|
|
|
# 在Fleet初始化配置中添加以下参数
|
|
|
-
|
|
|
-
|
|
|
- fleet.init(config=fleet_config)
|
|
|
+ fleet.init()
|
|
|
+ fleet.set_fleet_desc(fleet_config)
|
|
|
|
|
|
def network(self):
|
|
|
self.model = get_model(self.config)
|