# static_ps_infer_v2.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
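
# Distributed inference entry point for static-graph parameter-server models:
# the worker pulls a packaged model from OSS, loads it through fleet, runs
# infer_from_dataset over the dataset built by get_infer_reader, and streams
# the fetched variables to a JSONL file via InferenceFetchHandler below.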
from __future__ import print_function
import os
os.environ['FLAGS_enable_pir_api'] = '0'

import argparse
import ast
import json
import logging
import struct
import sys
import threading
import time
import warnings

import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.base.executor import FetchHandler

# Make the project root importable before pulling in the local helpers.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from utils.static_ps.reader_helper_hdfs import get_infer_reader
from utils.static_ps.program_helper import get_model, get_strategy, set_dump_config
from utils.static_ps.metric_helper import set_zero, get_global_auc
from utils.static_ps.common_ps import YamlHelper, is_distributed_env
from utils.utils_single import auc
from utils.oss_client import HangZhouOSSClient
import utils.compress as compress

# Reset any handlers that imported modules may have installed so that the
# basicConfig call below takes effect.
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
    root_logger.removeHandler(handler)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)


class InferenceFetchHandler(FetchHandler):
    """Collects variables fetched during infer_from_dataset and appends them to a JSONL file."""

    def __init__(self, var_dict, output_file, batch_size=1000):
        super().__init__(var_dict=var_dict, period_secs=1)
        self.output_file = output_file
        self.batch_size = batch_size
        self.current_batch = []
        self.total_samples = 0
        # Create the output directory if it does not exist.
        output_dir = os.path.dirname(output_file)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # Create (or truncate) the output file.
        with open(self.output_file, 'w') as f:
            f.write('')

    def handler(self, fetch_vars):
        """Process the variables fetched in one period."""
        result_dict = {}
        logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
        for var_name, var_value in fetch_vars.items():
            # Convert numpy arrays to plain lists so they are JSON serializable.
            if isinstance(var_value, np.ndarray):
                result = var_value.tolist()
            else:
                result = var_value
            result_dict[var_name] = result
        self.current_batch.append(result_dict)
        if result_dict:
            first_key = next(iter(result_dict))
            self.total_samples += len(result_dict[first_key])
        # Flush to disk once enough results have accumulated.
        if len(self.current_batch) >= self.batch_size:
            self._write_batch()
            logger.info(f"Saved {self.total_samples} samples to {self.output_file}")

    def _write_batch(self):
        """Append the buffered batch results to the output file as JSON lines."""
        with open(self.output_file, 'a') as f:
            for result in self.current_batch:
                f.write(json.dumps(result) + '\n')
        self.current_batch = []

    def finish(self):
        """Flush any remaining results and signal completion."""
        logger.info("InferenceFetchHandler finish")
        if self.current_batch:
            self._write_batch()
        logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}")
        self.done_event.set()
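

# A minimal usage sketch (hypothetical names), assuming `metrics` maps variable
# names to the Variables fetched during inference:
#
#   handler = InferenceFetchHandler(var_dict=metrics,
#                                   output_file="inference_results/out.jsonl")
#   exe.infer_from_dataset(program=..., dataset=..., fetch_handler=handler)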


def parse_args():
    parser = argparse.ArgumentParser("PaddleRec infer script")
    parser.add_argument("-o", "--opt", nargs='*', type=str)
    parser.add_argument(
        '-m',
        '--config_yaml',
        type=str,
        required=True,
        help='config file path')
    parser.add_argument(
        '-bf16',
        '--pure_bf16',
        type=ast.literal_eval,
        default=False,
        help="whether use bf16")
    args = parser.parse_args()
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    yaml_helper = YamlHelper()
    config = yaml_helper.load_yaml(args.config_yaml)
    # Apply command-line overrides of the form key=value, coercing the value
    # to the type of the existing config entry.
    if args.opt:
        for parameter in args.opt:
            parameter = parameter.strip()
            key, value = parameter.split("=", 1)
            if type(config.get(key)) is int:
                value = int(value)
            if type(config.get(key)) is float:
                value = float(value)
            if type(config.get(key)) is bool:
                value = (value.lower() == "true")
            config[key] = value
    config["yaml_path"] = args.config_yaml
    config["config_abs_dir"] = args.abs_dir
    config["pure_bf16"] = args.pure_bf16
    yaml_helper.print_yaml(config)
    return config
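

# Example invocation (hypothetical paths), overriding one config key from the
# command line via -o:
#
#   python static_ps_infer_v2.py -m models/rank/config.yaml \
#       -o runner.print_interval=10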


def bf16_to_fp32(val):
    # Reinterpret a bfloat16 bit pattern (the low 16 bits of an int) as a
    # float32 by shifting it into the high half of a 32-bit word.
    return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
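# For example, bf16_to_fp32(0x3F80) returns 1.0: 0x3F80 is the bfloat16 bit
# pattern of 1.0, and shifting it left by 16 yields the float32 pattern
# 0x3F800000.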


class Main(object):
    def __init__(self, config):
        self.metrics = {}
        self.config = config
        self.input_data = None
        self.reader = None
        self.exe = None
        self.train_result_dict = {}
        self.train_result_dict["speed"] = []
        self.train_result_dict["auc"] = []
        self.model = None
        self.pure_bf16 = self.config['pure_bf16']

    def run(self):
        self.init_fleet_with_gloo()
        self.network()
        if fleet.is_server():
            self.run_server()
        elif fleet.is_worker():
            self.run_worker()
            fleet.stop_worker()
            self.record_result()
        logger.info("Run Success, Exit.")

    def init_fleet_with_gloo(self, use_gloo=True):
        if use_gloo:
            os.environ["PADDLE_WITH_GLOO"] = "0"
            role = role_maker.PaddleCloudRoleMaker(
                is_collective=False, init_gloo=False)
            fleet.init(role)
        else:
            fleet.init()
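
    # NOTE: even when use_gloo=True, gloo is effectively disabled here
    # (PADDLE_WITH_GLOO="0", init_gloo=False); PaddleCloudRoleMaker then
    # derives this node's role from the environment variables set by the
    # distributed launcher (e.g. TRAINING_ROLE, PADDLE_PORT).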

    def network(self):
        self.model = get_model(self.config)
        self.input_data = self.model.create_feeds(is_infer=True)
        self.inference_feed_var = self.input_data
        self.init_reader()
        self.metrics = self.model.net(self.inference_feed_var, is_infer=True)
        self.inference_target_var = self.model.inference_target_var
        logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
        self.model.create_optimizer(get_strategy(self.config), is_infer=True)

    def run_server(self):
        logger.info("Run Server Begin")
        fleet.init_server(config.get("runner.warmup_model_path"))
        fleet.run_server()

    def run_worker(self):
        logger.info("Run Worker Begin")
        use_cuda = int(config.get("runner.use_gpu"))
        use_auc = config.get("runner.infer_use_auc", False)
        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
        self.exe = paddle.static.Executor(place)

        with open("./{}_worker_main_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_main_program()))
        with open("./{}_worker_startup_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_startup_program()))

        self.exe.run(paddle.static.default_startup_program())
        if self.pure_bf16:
            self.model.optimizer.amp_init(self.exe.place)
        fleet.init_worker()

        init_model_path = config.get("runner.infer_load_path")
        model_mode = config.get("runner.model_mode", 0)
        # Download the packaged model from OSS and unpack it into the load path.
        client = HangZhouOSSClient("art-recommend")
        oss_object_name = self.config.get("runner.oss_object_name",
                                          "dyp/model.tar.gz")
        client.get_object_to_file(oss_object_name, "model.tar.gz")
        compress.uncompress_tar("model.tar.gz", init_model_path)
        assert os.path.exists(init_model_path)
        #if fleet.is_first_worker():
        #    fleet.load_inference_model(init_model_path, mode=int(model_mode))
        #fleet.barrier_worker()

        reader_type = self.config.get("runner.reader_type", "QueueDataset")
        epochs = int(self.config.get("runner.epochs"))
        sync_mode = self.config.get("runner.sync_mode")

        opt_info = paddle.static.default_main_program()._fleet_opt
        if use_auc is True:
            opt_info['stat_var_names'] = [
                self.model.stat_pos.name, self.model.stat_neg.name
            ]
        else:
            opt_info['stat_var_names'] = []

        if reader_type == "InmemoryDataset":
            self.reader.load_into_memory()

        fleet.load_inference_model(init_model_path, mode=int(model_mode))

        epoch_start_time = time.time()
        epoch = 0
        if sync_mode == "heter":
            self.heter_train_loop(epoch)
        elif reader_type in ("QueueDataset", "InmemoryDataset"):
            self.dataset_train_loop(epoch)
        epoch_time = time.time() - epoch_start_time
        logger.info("using time {} second, ips {}/sec.".format(
            epoch_time, self.count_method))

        # Keep the worker process alive after inference finishes. NOTE: this
        # loop never exits, so the release_memory call below is unreachable.
        while True:
            time.sleep(300)
        if reader_type == "InmemoryDataset":
            self.reader.release_memory()

    def init_reader(self):
        if fleet.is_server():
            return
        self.config["runner.reader_type"] = self.config.get(
            "runner.reader_type", "QueueDataset")
        self.reader, self.file_list = get_infer_reader(self.input_data, config)
        self.example_nums = 0
        self.count_method = self.config.get("runner.example_count_method",
                                            "example")

    def dataset_train_loop(self, epoch=0):
        logger.info("Epoch: {}, Running Dataset Begin.".format(epoch))
        fetch_info = [
            "Epoch {} Var {}".format(epoch, var_name)
            for var_name in self.metrics
        ]
        fetch_vars = [var for _, var in self.metrics.items()]
        print_step = int(config.get("runner.print_interval"))

        debug = config.get("runner.dataset_debug", False)
        if config.get("runner.need_dump"):
            debug = True
            dump_fields_path = "{}/{}".format(
                config.get("runner.dump_fields_path"), epoch)
            set_dump_config(paddle.static.default_main_program(), {
                "dump_fields_path": dump_fields_path,
                "dump_fields": config.get("runner.dump_fields")
            })

        # Build the output file path for this epoch's results.
        output_dir = config.get("runner.inference_output_dir",
                                "inference_results")
        output_file = os.path.join(output_dir, f"epoch_{epoch}_results.jsonl")
        # Create the handler instance that streams fetched variables to disk.
        fetch_handler = InferenceFetchHandler(
            var_dict=self.metrics, output_file=output_file)
        # fetch_handler.set_var_dict(self.metrics)

        print(paddle.static.default_main_program()._fleet_opt)
        self.exe.infer_from_dataset(
            program=paddle.static.default_main_program(),
            dataset=self.reader,
            fetch_list=fetch_vars,
            fetch_info=fetch_info,
            print_period=print_step,
            debug=debug,
            fetch_handler=fetch_handler)
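
    # NOTE: the fetch_handler runs on a background monitor thread and fires
    # once per period_secs (1 second here), so the JSONL output samples the
    # fetch variables periodically rather than recording every mini-batch.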

    def heter_train_loop(self, epoch):
        logger.info(
            "Epoch: {}, Running Begin. Check running metrics at heter_log".
            format(epoch))
        reader_type = self.config.get("runner.reader_type")
        if reader_type == "QueueDataset":
            self.exe.infer_from_dataset(
                program=paddle.static.default_main_program(),
                dataset=self.reader,
                debug=config.get("runner.dataset_debug"))
        elif reader_type == "DataLoader":
            batch_id = 0
            train_run_cost = 0.0
            total_examples = 0
            self.reader.start()
            while True:
                try:
                    train_start = time.time()
                    self.exe.run(program=paddle.static.default_main_program())
                    train_run_cost += time.time() - train_start
                    total_examples += self.config.get("runner.batch_size")
                    batch_id += 1
                    print_step = int(config.get("runner.print_period"))
                    if batch_id % print_step == 0:
                        profiler_string = ""
                        profiler_string += "avg_batch_cost: {} sec, ".format(
                            format(train_run_cost / print_step, '.5f'))
                        profiler_string += "avg_samples: {}, ".format(
                            format(total_examples / print_step, '.5f'))
                        profiler_string += "ips: {} {}/sec ".format(
                            format(total_examples / train_run_cost, '.5f'),
                            self.count_method)
                        logger.info("Epoch: {}, Batch: {}, {}".format(
                            epoch, batch_id, profiler_string))
                        train_run_cost = 0.0
                        total_examples = 0
                except paddle.core.EOFException:
                    self.reader.reset()
                    break

    def record_result(self):
        logger.info("train_result_dict: {}".format(self.train_result_dict))
        with open("./train_result_dict.txt", 'w+') as f:
            f.write(str(self.train_result_dict))


if __name__ == "__main__":
    paddle.enable_static()
    config = parse_args()
    os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
    benchmark_main = Main(config)
    benchmark_main.run()
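
# An illustrative runner section for config_yaml (key names are taken from the
# config.get(...) calls above; the values are placeholders):
#
#   runner:
#     use_gpu: 0
#     thread_num: 8
#     epochs: 1
#     sync_mode: async
#     reader_type: InmemoryDataset
#     print_interval: 10
#     infer_load_path: ./infer_model
#     model_mode: 0
#     oss_object_name: dyp/model.tar.gz
#     inference_output_dir: inference_results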