|
@@ -111,8 +111,11 @@ class Main(object):
|
|
|
|
|
|
def init_fleet_with_gloo(use_gloo=True):
|
|
def init_fleet_with_gloo(use_gloo=True):
|
|
if use_gloo:
|
|
if use_gloo:
|
|
- os.environ["PADDLE_WITH_GLOO"] = "1"
|
|
|
|
- role = role_maker.PaddleCloudRoleMaker()
|
|
|
|
|
|
+ os.environ["PADDLE_WITH_GLOO"] = "0"
|
|
|
|
+ role = role_maker.PaddleCloudRoleMaker(
|
|
|
|
+ is_collective=False,
|
|
|
|
+ init_gloo=False
|
|
|
|
+ )
|
|
fleet.init(role)
|
|
fleet.init(role)
|
|
else:
|
|
else:
|
|
fleet.init()
|
|
fleet.init()
|
|
@@ -193,7 +196,7 @@ class Main(object):
|
|
self.dataset_train_loop(epoch)
|
|
self.dataset_train_loop(epoch)
|
|
|
|
|
|
epoch_time = time.time() - epoch_start_time
|
|
epoch_time = time.time() - epoch_start_time
|
|
- epoch_speed = self.example_nums / epoch_time
|
|
|
|
|
|
+
|
|
if use_auc is True:
|
|
if use_auc is True:
|
|
global_auc = get_global_auc(paddle.static.global_scope(),
|
|
global_auc = get_global_auc(paddle.static.global_scope(),
|
|
self.model.stat_pos.name,
|
|
self.model.stat_pos.name,
|
|
@@ -208,15 +211,13 @@ class Main(object):
|
|
set_zero(self.model.batch_stat_neg.name,
|
|
set_zero(self.model.batch_stat_neg.name,
|
|
paddle.static.global_scope())
|
|
paddle.static.global_scope())
|
|
logger.info(
|
|
logger.info(
|
|
- "Epoch: {}, using time: {} second, ips: {} {}/sec. auc: {}".
|
|
|
|
- format(epoch, epoch_time, epoch_speed, self.count_method,
|
|
|
|
|
|
+ "Epoch: {}, using time: {} second, ips: {}/sec. auc: {}".
|
|
|
|
+ format(epoch, epoch_time, self.count_method,
|
|
global_auc))
|
|
global_auc))
|
|
else:
|
|
else:
|
|
logger.info(
|
|
logger.info(
|
|
- "Epoch: {}, using time {} second, ips {} {}/sec.".format(
|
|
|
|
- epoch, epoch_time, epoch_speed, self.count_method))
|
|
|
|
-
|
|
|
|
- self.train_result_dict["speed"].append(epoch_speed)
|
|
|
|
|
|
+ "Epoch: {}, using time {} second, ips {}/sec.".format(
|
|
|
|
+ epoch, epoch_time, self.count_method))
|
|
|
|
|
|
model_dir = "{}/{}".format(save_model_path, epoch)
|
|
model_dir = "{}/{}".format(save_model_path, epoch)
|
|
|
|
|
|
@@ -232,22 +233,17 @@ class Main(object):
|
|
self.example_nums = 0
|
|
self.example_nums = 0
|
|
self.count_method = self.config.get("runner.example_count_method",
|
|
self.count_method = self.config.get("runner.example_count_method",
|
|
"example")
|
|
"example")
|
|
- if self.count_method == "example":
|
|
|
|
- self.example_nums = get_example_num(self.file_list)
|
|
|
|
- elif self.count_method == "word":
|
|
|
|
- self.example_nums = get_word_num(self.file_list)
|
|
|
|
- else:
|
|
|
|
- raise ValueError(
|
|
|
|
- "Set static_benchmark.example_count_method for example / word for example count."
|
|
|
|
- )
|
|
|
|
|
|
|
|
def dataset_train_loop(self, epoch):
|
|
def dataset_train_loop(self, epoch):
|
|
logger.info("Epoch: {}, Running Dataset Begin.".format(epoch))
|
|
logger.info("Epoch: {}, Running Dataset Begin.".format(epoch))
|
|
|
|
+
|
|
fetch_info = [
|
|
fetch_info = [
|
|
"Epoch {} Var {}".format(epoch, var_name)
|
|
"Epoch {} Var {}".format(epoch, var_name)
|
|
for var_name in self.metrics
|
|
for var_name in self.metrics
|
|
]
|
|
]
|
|
|
|
+
|
|
fetch_vars = [var for _, var in self.metrics.items()]
|
|
fetch_vars = [var for _, var in self.metrics.items()]
|
|
|
|
+
|
|
print_step = int(config.get("runner.print_interval"))
|
|
print_step = int(config.get("runner.print_interval"))
|
|
|
|
|
|
debug = config.get("runner.dataset_debug", False)
|
|
debug = config.get("runner.dataset_debug", False)
|
|
@@ -268,6 +264,7 @@ class Main(object):
|
|
print_period=print_step,
|
|
print_period=print_step,
|
|
debug=debug)
|
|
debug=debug)
|
|
|
|
|
|
|
|
+
|
|
def heter_train_loop(self, epoch):
|
|
def heter_train_loop(self, epoch):
|
|
logger.info(
|
|
logger.info(
|
|
"Epoch: {}, Running Begin. Check running metrics at heter_log".
|
|
"Epoch: {}, Running Begin. Check running metrics at heter_log".
|
|
@@ -318,7 +315,9 @@ class Main(object):
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
paddle.enable_static()
|
|
paddle.enable_static()
|
|
|
|
+
|
|
config = parse_args()
|
|
config = parse_args()
|
|
os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
|
|
os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
|
|
benchmark_main = Main(config)
|
|
benchmark_main = Main(config)
|
|
|
|
+
|
|
benchmark_main.run()
|
|
benchmark_main.run()
|