# static_ps_trainer_v2.py
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import print_function
  15. import os
  16. os.environ['FLAGS_enable_pir_api'] = '0'
  17. from utils.static_ps.reader_helper_hdfs import get_reader
  18. from utils.static_ps.program_helper import get_model, get_strategy, set_dump_config
  19. from utils.static_ps.metric_helper import set_zero, get_global_auc
  20. from utils.static_ps.common_ps import YamlHelper, is_distributed_env
  21. import argparse
  22. import time
  23. import sys
  24. import paddle.distributed.fleet as fleet
  25. import paddle.distributed.fleet.base.role_maker as role_maker
  26. import paddle
  27. import paddle.base.core as core
  28. import warnings
  29. import logging
  30. import ast
  31. import numpy as np
  32. import struct
  33. from utils.utils_single import auc
  34. from utils.oss_client import HangZhouOSSClient
  35. import utils.compress as compress
  36. sys.path.append(os.path.dirname(os.path.abspath("")) + os.sep + "lib")
  37. print(os.path.dirname(os.path.abspath("")) + os.sep + "lib1")
  38. import brpc_flags
  39. print(os.path.dirname(os.path.abspath("")) + os.sep + "lib2")
  40. __dir__ = os.path.dirname(os.path.abspath(__file__))
  41. sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
  42. root_loger = logging.getLogger()
  43. for handler in root_loger.handlers[:]:
  44. root_loger.removeHandler(handler)
  45. logging.basicConfig(
  46. format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
  47. logger = logging.getLogger(__name__)
  48. def parse_args():
  49. parser = argparse.ArgumentParser("PaddleRec train script")
  50. parser.add_argument("-o", "--opt", nargs='*', type=str)
  51. parser.add_argument(
  52. '-m',
  53. '--config_yaml',
  54. type=str,
  55. required=True,
  56. help='config file path')
  57. parser.add_argument(
  58. '-bf16',
  59. '--pure_bf16',
  60. type=ast.literal_eval,
  61. default=False,
  62. help="whether use bf16")
  63. args = parser.parse_args()
  64. args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
  65. yaml_helper = YamlHelper()
  66. config = yaml_helper.load_yaml(args.config_yaml)
  67. # modify config from command
  68. if args.opt:
  69. for parameter in args.opt:
  70. parameter = parameter.strip()
  71. key, value = parameter.split("=")
  72. if type(config.get(key)) is int:
  73. value = int(value)
  74. if type(config.get(key)) is float:
  75. value = float(value)
  76. if type(config.get(key)) is bool:
  77. value = (True if value.lower() == "true" else False)
  78. config[key] = value
  79. config["yaml_path"] = args.config_yaml
  80. config["config_abs_dir"] = args.abs_dir
  81. config["pure_bf16"] = args.pure_bf16
  82. yaml_helper.print_yaml(config)
  83. return config
  84. def bf16_to_fp32(val):
  85. return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
  86. class Main(object):
  87. def __init__(self, config):
  88. self.metrics = {}
  89. self.config = config
  90. self.input_data = None
  91. self.reader = None
  92. self.exe = None
  93. self.train_result_dict = {}
  94. self.train_result_dict["speed"] = []
  95. self.train_result_dict["auc"] = []
  96. self.model = None
  97. self.pure_bf16 = self.config['pure_bf16']
  98. def run(self):
  99. self.init_fleet_with_gloo()
  100. self.network()
  101. if fleet.is_server():
  102. self.run_server()
  103. elif fleet.is_worker():
  104. self.run_worker()
  105. fleet.stop_worker()
  106. self.record_result()
  107. logger.info("Run Success, Exit.")
  108. def init_fleet_with_gloo(use_gloo=True):
  109. fleet_config = {
  110. "max_body_size": 256 * 1024 * 1024, # 设置为256MB
  111. }
  112. if use_gloo:
  113. os.environ["PADDLE_WITH_GLOO"] = "0"
  114. role = role_maker.PaddleCloudRoleMaker(
  115. is_collective=False,
  116. init_gloo=False
  117. )
  118. fleet.init(role)
  119. #logger.info("worker_index: %s", fleet.worker_index())
  120. #logger.info("is_first_worker: %s", fleet.is_first_worker())
  121. #logger.info("worker_num: %s", fleet.worker_num())
  122. #logger.info("is_distributed: %s", fleet.is_distributed())
  123. #logger.info("mode: %s", fleet.mode)
  124. else:
  125. # 在Fleet初始化配置中添加以下参数
  126. fleet.init()
  127. #fleet.set_fleet_desc(fleet_config)
  128. def network(self):
  129. self.model = get_model(self.config)
  130. self.input_data = self.model.create_feeds()
  131. self.inference_feed_var = self.model.create_feeds(is_infer=True)
  132. self.init_reader()
  133. self.metrics = self.model.net(self.input_data)
  134. self.inference_target_var = self.model.inference_target_var
  135. logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
  136. self.model.create_optimizer(get_strategy(self.config))
  137. def run_server(self):
  138. logger.info("Run Server Begin")
  139. fleet.init_server(config.get("runner.warmup_model_path"))
  140. fleet.run_server()
  141. def run_worker(self):
  142. logger.info("Run Worker Begin")
  143. use_cuda = int(config.get("runner.use_gpu"))
  144. use_auc = config.get("runner.use_auc", False)
  145. place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
  146. self.exe = paddle.static.Executor(place)
  147. with open("./{}_worker_main_program.prototxt".format(
  148. fleet.worker_index()), 'w+') as f:
  149. f.write(str(paddle.static.default_main_program()))
  150. with open("./{}_worker_startup_program.prototxt".format(
  151. fleet.worker_index()), 'w+') as f:
  152. f.write(str(paddle.static.default_startup_program()))
  153. self.exe.run(paddle.static.default_startup_program())
  154. if self.pure_bf16:
  155. self.model.optimizer.amp_init(self.exe.place)
  156. fleet.init_worker()
  157. save_model_path = self.config.get("runner.model_save_path")
  158. if save_model_path and (not os.path.exists(save_model_path)):
  159. os.makedirs(save_model_path)
  160. reader_type = self.config.get("runner.reader_type", "QueueDataset")
  161. epochs = int(self.config.get("runner.epochs"))
  162. sync_mode = self.config.get("runner.sync_mode")
  163. opt_info = paddle.static.default_main_program()._fleet_opt
  164. if use_auc is True:
  165. opt_info['stat_var_names'] = [
  166. self.model.stat_pos.name, self.model.stat_neg.name
  167. ]
  168. else:
  169. opt_info['stat_var_names'] = []
  170. if reader_type == "InmemoryDataset":
  171. self.reader.load_into_memory()
  172. for epoch in range(epochs):
  173. epoch_start_time = time.time()
  174. if sync_mode == "heter":
  175. self.heter_train_loop(epoch)
  176. elif reader_type == "QueueDataset":
  177. self.dataset_train_loop(epoch)
  178. elif reader_type == "InmemoryDataset":
  179. self.dataset_train_loop(epoch)
  180. epoch_time = time.time() - epoch_start_time
  181. if use_auc is True:
  182. global_auc = get_global_auc(paddle.static.global_scope(),
  183. self.model.stat_pos.name,
  184. self.model.stat_neg.name)
  185. self.train_result_dict["auc"].append(global_auc)
  186. set_zero(self.model.stat_pos.name,
  187. paddle.static.global_scope())
  188. set_zero(self.model.stat_neg.name,
  189. paddle.static.global_scope())
  190. set_zero(self.model.batch_stat_pos.name,
  191. paddle.static.global_scope())
  192. set_zero(self.model.batch_stat_neg.name,
  193. paddle.static.global_scope())
  194. logger.info(
  195. "Epoch: {}, using time: {} second, ips: {}/sec. auc: {}".
  196. format(epoch, epoch_time, self.count_method,
  197. global_auc))
  198. else:
  199. logger.info(
  200. "Epoch: {}, using time {} second, ips {}/sec.".format(
  201. epoch, epoch_time, self.count_method))
  202. model_dir = "{}/{}".format(save_model_path, epoch)
  203. if is_distributed_env():
  204. fleet.save_inference_model(
  205. self.exe, model_dir,
  206. [feed.name for feed in self.inference_feed_var],
  207. self.inference_target_var)
  208. else:
  209. paddle.static.save_inference_model(
  210. model_dir,
  211. [feed.name for feed in self.inference_feed_var],
  212. [self.inference_target_var], self.exe)
  213. compress.compress_tar(model_dir, "test")
  214. client = HangZhouOSSClient("art-recommend")
  215. client.put_object_from_file("dyp/test.tar.gz", "test.tar.gz")
  216. if reader_type == "InmemoryDataset":
  217. self.reader.release_memory()
  218. def init_reader(self):
  219. if fleet.is_server():
  220. return
  221. self.config["runner.reader_type"] = self.config.get(
  222. "runner.reader_type", "QueueDataset")
  223. self.reader, self.file_list = get_reader(self.input_data, config)
  224. self.example_nums = 0
  225. self.count_method = self.config.get("runner.example_count_method",
  226. "example")
  227. def dataset_train_loop(self, epoch):
  228. logger.info("Epoch: {}, Running Dataset Begin.".format(epoch))
  229. fetch_info = [
  230. "Epoch {} Var {}".format(epoch, var_name)
  231. for var_name in self.metrics
  232. ]
  233. logger.info("Epoch: {},{}, Running Dataset Begin.".format(epoch,self.metrics))
  234. fetch_vars = [var for _, var in self.metrics.items()]
  235. print_step = int(config.get("runner.print_interval"))
  236. debug = config.get("runner.dataset_debug", False)
  237. if config.get("runner.need_dump"):
  238. debug = True
  239. dump_fields_path = "{}/{}".format(
  240. config.get("runner.dump_fields_path"), epoch)
  241. set_dump_config(paddle.static.default_main_program(), {
  242. "dump_fields_path": dump_fields_path,
  243. "dump_fields": config.get("runner.dump_fields")
  244. })
  245. # logger.info(paddle.static.default_main_program()._fleet_opt)
  246. self.exe.train_from_dataset(
  247. program=paddle.static.default_main_program(),
  248. dataset=self.reader,
  249. fetch_list=fetch_vars,
  250. fetch_info=fetch_info,
  251. print_period=print_step,
  252. debug=debug)
  253. def heter_train_loop(self, epoch):
  254. logger.info(
  255. "Epoch: {}, Running Begin. Check running metrics at heter_log".
  256. format(epoch))
  257. reader_type = self.config.get("runner.reader_type")
  258. if reader_type == "QueueDataset":
  259. self.exe.train_from_dataset(
  260. program=paddle.static.default_main_program(),
  261. dataset=self.reader,
  262. debug=config.get("runner.dataset_debug"))
  263. elif reader_type == "DataLoader":
  264. batch_id = 0
  265. train_run_cost = 0.0
  266. total_examples = 0
  267. self.reader.start()
  268. while True:
  269. try:
  270. train_start = time.time()
  271. # --------------------------------------------------- #
  272. self.exe.run(program=paddle.static.default_main_program())
  273. # --------------------------------------------------- #
  274. train_run_cost += time.time() - train_start
  275. total_examples += self.config.get("runner.batch_size")
  276. batch_id += 1
  277. print_step = int(config.get("runner.print_period"))
  278. if batch_id % print_step == 0:
  279. profiler_string = ""
  280. profiler_string += "avg_batch_cost: {} sec, ".format(
  281. format((train_run_cost) / print_step, '.5f'))
  282. profiler_string += "avg_samples: {}, ".format(
  283. format(total_examples / print_step, '.5f'))
  284. profiler_string += "ips: {} {}/sec ".format(
  285. format(total_examples / (train_run_cost), '.5f'),
  286. self.count_method)
  287. logger.info("Epoch: {}, Batch: {}, {}".format(
  288. epoch, batch_id, profiler_string))
  289. train_run_cost = 0.0
  290. total_examples = 0
  291. except paddle.core.EOFException:
  292. self.reader.reset()
  293. break
  294. def record_result(self):
  295. logger.info("train_result_dict: {}".format(self.train_result_dict))
  296. with open("./train_result_dict.txt", 'w+') as f:
  297. f.write(str(self.train_result_dict))
if __name__ == "__main__":
    # Raise brpc's max message body size before fleet starts exchanging
    # large payloads; print the value before and after to verify the call
    # took effect. (brpc_flags is a local compiled extension — presumably
    # a thin wrapper over brpc gflags; confirm against lib/ sources.)
    print("get_max_body_size1")
    print(brpc_flags.get_max_body_size())
    brpc_flags.set_max_body_size(123456789)
    print("get_max_body_size2")
    print(brpc_flags.get_max_body_size())
    paddle.enable_static()
    # read_env_flags = [
    # key[len("FLAGS_") :]
    # for key in core.globals().keys()
    # if key.startswith("FLAGS_")
    # ]
    # def remove_flag_if_exists(name):
    # if name in read_env_flags:
    # read_env_flags.remove(name)
    # # core.init_gflags(["--tryfromenv=,max_body_size" + ",".join(read_env_flags)])
    # print('global_flags: ' + str(list(paddle.base.framework._global_flags().keys())))
    # NOTE: ``config`` is a module-level global; Main's methods read it
    # directly in several places in addition to self.config.
    config = parse_args()
    os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
    benchmark_main = Main(config)
    benchmark_main.run()