trainer.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ['FLAGS_enable_pir_api'] = '0'
import paddle
import paddle.nn as nn
import time
import logging
import sys
import importlib
import argparse

__dir__ = os.path.dirname(os.path.abspath(__file__))
#sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from utils.utils_single import load_yaml, load_dy_model_class, get_abs_model, create_data_loader
from utils.save_load import load_model, save_model
from paddle.io import DistributedBatchSampler, DataLoader

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def parse_args():
    parser = argparse.ArgumentParser(description='paddle-rec run')
    parser.add_argument("-m", "--config_yaml", type=str)
    parser.add_argument("-o", "--opt", nargs='*', type=str)
    args = parser.parse_args()
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    args.config_yaml = get_abs_model(args.config_yaml)
    return args
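
# Typical invocation (the config path shown is only an example):
#   python -u trainer.py -m models/rank/dnn/config.yaml -o runner.epochs=2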


def main(args):
    # load config
    config = load_yaml(args.config_yaml)
    dy_model_class = load_dy_model_class(args.abs_dir)
    config["config_abs_dir"] = args.abs_dir

    # override config entries from the command line ("-o key=value ...")
    if args.opt:
        for parameter in args.opt:
            parameter = parameter.strip()
            # split on the first '=' only, so values may themselves contain '='
            key, value = parameter.split("=", 1)
            # cast the string value to the type of the existing config entry
            if type(config.get(key)) is int:
                value = int(value)
            if type(config.get(key)) is float:
                value = float(value)
            if type(config.get(key)) is bool:
                value = (value.lower() == "true")
            config[key] = value

    # runner configuration
    use_gpu = config.get("runner.use_gpu", True)
    use_auc = config.get("runner.use_auc", False)
    use_npu = config.get("runner.use_npu", False)
    use_xpu = config.get("runner.use_xpu", False)
    use_visual = config.get("runner.use_visual", False)
    train_data_dir = config.get("runner.train_data_dir", None)
    epochs = config.get("runner.epochs", None)
    print_interval = config.get("runner.print_interval", None)
    train_batch_size = config.get("runner.train_batch_size", None)
    model_save_path = config.get("runner.model_save_path", "model_output")
    model_init_path = config.get("runner.model_init_path", None)
    use_fleet = config.get("runner.use_fleet", False)
    seed = config.get("runner.seed", 12345)
    paddle.seed(seed)

    logger.info("**************common.configs**********")
    logger.info(
        "use_gpu: {}, use_xpu: {}, use_npu: {}, use_visual: {}, "
        "train_batch_size: {}, train_data_dir: {}, epochs: {}, "
        "print_interval: {}, model_save_path: {}".format(
            use_gpu, use_xpu, use_npu, use_visual, train_batch_size,
            train_data_dir, epochs, print_interval, model_save_path))
    logger.info("**************common.configs**********")
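
    # FLAGS_selected_xpus / FLAGS_selected_npus below are normally exported
    # per process by the distributed launcher; the fallback of 0 targets the
    # first device.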
    if use_xpu:
        xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
        place = paddle.set_device(xpu_device)
    elif use_npu:
        npu_device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
        place = paddle.set_device(npu_device)
    else:
        place = paddle.set_device('gpu' if use_gpu else 'cpu')

    dy_model = dy_model_class.create_model(config)

    # create a LogWriter and store visualization data under the given path
    if use_visual:
        from visualdl import LogWriter
        log_visual = LogWriter(args.abs_dir + "/visualDL_log/train")

    # warm start from an existing checkpoint if one is configured
    if model_init_path is not None:
        load_model(model_init_path, dy_model)

    # the optimizer is defined by the model class
    optimizer = dy_model_class.create_optimizer(dy_model, config)

    # use fleet to run collective training
    if use_fleet:
        from paddle.distributed import fleet
        strategy = fleet.DistributedStrategy()
        fleet.init(is_collective=True, strategy=strategy)
        optimizer = fleet.distributed_optimizer(optimizer)
        dy_model = fleet.distributed_model(dy_model)
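
    # Collective runs are typically started through Paddle's launcher, e.g.
    # (command shown is only an example):
    #   python -m paddle.distributed.launch trainer.py -m config.yaml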

    logger.info("read data")
    train_dataloader = create_data_loader(config=config, place=place)

    last_epoch_id = config.get("last_epoch", -1)
    step_num = 0

    for epoch_id in range(last_epoch_id + 1, epochs):
        # set train mode
        dy_model.train()
        metric_list, metric_list_name = dy_model_class.create_metrics()
        #auc_metric = paddle.metric.Auc("ROC")
        epoch_begin = time.time()
        interval_begin = time.time()
        train_reader_cost = 0.0
        train_run_cost = 0.0
        total_samples = 0
        reader_start = time.time()

        # the last incomplete batch is dropped when the dataset size is not
        # divisible by the batch size, so make sure at least one batch remains
        assert any(train_dataloader()), \
            "train_dataloader is null, please ensure batch size < dataset size!"
        for batch_id, batch in enumerate(train_dataloader()):
            train_reader_cost += time.time() - reader_start
            optimizer.clear_grad()
            train_start = time.time()
            batch_size = len(batch[0])

            loss, metric_list, tensor_print_dict = dy_model_class.train_forward(
                dy_model, metric_list, batch, config)
            loss.backward()
            optimizer.step()

            train_run_cost += time.time() - train_start
            total_samples += batch_size

            if batch_id % print_interval == 0:
                metric_str = ""
                for metric_id in range(len(metric_list_name)):
                    metric_str += (
                        metric_list_name[metric_id] +
                        ":{:.6f}, ".format(metric_list[metric_id].accumulate()))
                    if use_visual:
                        log_visual.add_scalar(
                            tag="train/" + metric_list_name[metric_id],
                            step=step_num,
                            value=metric_list[metric_id].accumulate())
                tensor_print_str = ""
                if tensor_print_dict is not None:
                    for var_name, var in tensor_print_dict.items():
                        tensor_print_str += (
                            "{}:".format(var_name) +
                            str(var.numpy()).strip("[]") + ",")
                        if use_visual:
                            log_visual.add_scalar(
                                tag="train/" + var_name,
                                step=step_num,
                                value=var.numpy())
                logger.info(
                    "epoch: {}, batch_id: {}, ".format(epoch_id, batch_id) +
                    metric_str + tensor_print_str +
                    " avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, "
                    "avg_samples: {:.5f}, ips: {:.5f} ins/s".format(
                        train_reader_cost / print_interval,
                        (train_reader_cost + train_run_cost) / print_interval,
                        total_samples / print_interval,
                        total_samples /
                        (train_reader_cost + train_run_cost + 0.0001)))
                train_reader_cost = 0.0
                train_run_cost = 0.0
                total_samples = 0
            reader_start = time.time()
            step_num = step_num + 1

        # log accumulated metrics at the end of the epoch
        metric_str = ""
        for metric_id in range(len(metric_list_name)):
            metric_str += (
                metric_list_name[metric_id] +
                ": {:.6f},".format(metric_list[metric_id].accumulate()))
            if use_auc:
                metric_list[metric_id].reset()
        tensor_print_str = ""
        if tensor_print_dict is not None:
            for var_name, var in tensor_print_dict.items():
                tensor_print_str += (
                    "{}:".format(var_name) + str(var.numpy()).strip("[]") + ",")
        logger.info("epoch: {} done, ".format(epoch_id) + metric_str +
                    tensor_print_str + " epoch time: {:.2f} s".format(
                        time.time() - epoch_begin))

        # save the checkpoint; under fleet, only rank 0 writes to disk
        if use_fleet:
            trainer_id = paddle.distributed.get_rank()
            if trainer_id == 0:
                save_model(
                    dy_model, optimizer, model_save_path, epoch_id,
                    prefix='rec')
        else:
            save_model(
                dy_model, optimizer, model_save_path, epoch_id, prefix='rec')


if __name__ == '__main__':
    args = parse_args()
    main(args)