# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
os.environ['FLAGS_enable_pir_api'] = '0'
from utils.static_ps.reader_helper import get_reader, get_infer_reader, get_example_num, get_file_list, get_word_num
from utils.static_ps.program_helper import get_model, get_strategy, set_dump_config
from utils.static_ps.common_ps import YamlHelper, is_distributed_env
import argparse
import time
import sys
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.ps.coordinator import FLClient
import paddle
import warnings
import logging
import ast
import numpy as np
import struct

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
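

# Parse command-line arguments (-m/--config_yaml, -bf16/--pure_bf16), load the
# YAML config, and return it as a dict enriched with the yaml path and abs dir.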
def parse_args():
    parser = argparse.ArgumentParser("PaddleRec train script")
    parser.add_argument(
        '-m',
        '--config_yaml',
        type=str,
        required=True,
        help='config file path')
    parser.add_argument(
        '-bf16',
        '--pure_bf16',
        type=ast.literal_eval,
        default=False,
        help="whether to use bf16")
    args = parser.parse_args()
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    yaml_helper = YamlHelper()
    config = yaml_helper.load_yaml(args.config_yaml)
    config["yaml_path"] = args.config_yaml
    config["config_abs_dir"] = args.abs_dir
    config["pure_bf16"] = args.pure_bf16
    yaml_helper.print_yaml(config)
    return config
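

# Reinterpret a bf16 value (held in the low 16 bits of an int) as float32 by
# shifting it into the high 16 bits of an IEEE-754 single-precision pattern.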
def bf16_to_fp32(val):
    return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
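

# Thin wrapper around Paddle's FLClient; configuration is injected later via
# set_basic_config / set_*_dataset_info rather than in the constructor.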
class MyFLClient(FLClient):
    def __init__(self):
        pass
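

# Drives a static-graph federated-learning job: builds the network, then runs
# as parameter server, trainer (worker), or coordinator depending on the role.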
class Trainer(object):
    def __init__(self, config):
        self.metrics = {}
        self.config = config
        self.input_data = None
        self.train_dataset = None
        self.test_dataset = None
        self.model = None
        self.pure_bf16 = self.config['pure_bf16']
        # Fall back to CPU when runner.use_gpu is not set in the config.
        self.use_cuda = int(self.config.get("runner.use_gpu", 0))
        self.place = paddle.CUDAPlace(0) if self.use_cuda else paddle.CPUPlace()
        self.role = None
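
    # Dispatch to the server / worker / coordinator loop based on the fleet role.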
    def run(self):
        self.init_fleet()
        self.init_network()
        if fleet.is_server():
            self.run_server()
        elif fleet.is_worker():
            self.init_reader()
            self.run_worker()
        elif fleet.is_coordinator():
            self.run_coordinator()
        logger.info("Run Success, Exit.")
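
    # Initialize fleet; with Gloo enabled, a PaddleCloudRoleMaker provides the
    # node's role (server / worker / coordinator).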
    def init_fleet(self, use_gloo=True):
        if use_gloo:
            os.environ["PADDLE_WITH_GLOO"] = "1"
            self.role = role_maker.PaddleCloudRoleMaker()
            fleet.init(self.role)
        else:
            fleet.init()
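
    # Build feeds, the model network and metrics, and create the optimizer with
    # the distributed strategy derived from the config.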
    def init_network(self):
        self.model = get_model(self.config)
        self.input_data = self.model.create_feeds()
        self.metrics = self.model.net(self.input_data)
        self.model.create_optimizer(get_strategy(self.config))
        if self.pure_bf16:
            self.model.optimizer.amp_init(self.place)
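
    # Create train/infer datasets, hand them to the FL client, and report the
    # number of training examples (counted by example or by word).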
    def init_reader(self):
        self.train_dataset, self.train_file_list = get_reader(self.input_data,
                                                              self.config)
        self.test_dataset, self.test_file_list = get_infer_reader(
            self.input_data, self.config)

        if self.role is not None:
            self.fl_client = MyFLClient()
            self.fl_client.set_basic_config(self.role, self.config,
                                            self.metrics)
        else:
            raise ValueError("self.role is None, call init_fleet() first")

        self.fl_client.set_train_dataset_info(self.train_dataset,
                                              self.train_file_list)
        self.fl_client.set_test_dataset_info(self.test_dataset,
                                             self.test_file_list)

        example_nums = 0
        self.count_method = self.config.get("runner.example_count_method",
                                            "example")
        if self.count_method == "example":
            example_nums = get_example_num(self.train_file_list)
        elif self.count_method == "word":
            example_nums = get_word_num(self.train_file_list)
        else:
            raise ValueError(
                "runner.example_count_method must be 'example' or 'word'")
        self.fl_client.set_train_example_num(example_nums)

    def run_coordinator(self):
        logger.info("Run Coordinator Begin")
        fleet.init_coordinator()
        fleet.make_fl_strategy()
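
    # Parameter-server loop: optionally warm-start from runner.warmup_model_path.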
    def run_server(self):
        logger.info("Run Server Begin")
        fleet.init_server(self.config.get("runner.warmup_model_path"))
        fleet.run_server()

    def run_worker(self):
        logger.info("Run Worker Begin")
        self.fl_client.run()
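

# Entry point: static graph mode, config from YAML, CPU_NUM from runner.thread_num.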
if __name__ == "__main__":
    paddle.enable_static()
    config = parse_args()
    # Default to a single CPU thread when runner.thread_num is not configured.
    os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1))
    trainer = Trainer(config)
    trainer.run()