# static_ps_infer_v2.py
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import print_function
import os
# Force the legacy (non-PIR) static-graph API path.
# NOTE(review): set immediately after importing os, before paddle is
# imported — presumably the flag must be in the environment at paddle
# import time to take effect; confirm against Paddle's flag handling.
os.environ['FLAGS_enable_pir_api'] = '0'
  17. from utils.static_ps.reader_helper_hdfs import get_infer_reader
  18. from utils.static_ps.program_helper import get_model, get_strategy, set_dump_config
  19. from utils.static_ps.metric_helper import set_zero, get_global_auc
  20. from utils.static_ps.common_ps import YamlHelper, is_distributed_env
  21. import argparse
  22. import time
  23. import sys
  24. import paddle.distributed.fleet as fleet
  25. import paddle.distributed.fleet.base.role_maker as role_maker
  26. import paddle
  27. from paddle.base.executor import FetchHandler
  28. import threading
  29. import warnings
  30. import logging
  31. import ast
  32. import numpy as np
  33. import struct
  34. from utils.utils_single import auc
  35. from utils.oss_client import HangZhouOSSClient
  36. import utils.compress as compress
  37. __dir__ = os.path.dirname(os.path.abspath(__file__))
  38. sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
  39. root_loger = logging.getLogger()
  40. for handler in root_loger.handlers[:]:
  41. root_loger.removeHandler(handler)
  42. logging.basicConfig(
  43. format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
  44. logger = logging.getLogger(__name__)
  45. import json
  46. class InferenceFetchHandler(object):
  47. def __init__(self, output_file, batch_size=1000):
  48. self.output_file = output_file
  49. self.batch_size = batch_size
  50. self.current_batch = []
  51. self.total_samples = 0
  52. # 添加 Paddle 需要的属性
  53. self.period_secs = 1 # 设置默认的周期时间(秒)
  54. self.done_event = threading.Event() # 添加完成事件
  55. self.terminal_event = threading.Event() # 添加终止事件
  56. # 创建输出目录(如果不存在)
  57. output_dir = os.path.dirname(output_file)
  58. if not os.path.exists(output_dir):
  59. os.makedirs(output_dir)
  60. self.var_dict = {} # 用于存储需要获取的变量
  61. # 创建或清空输出文件
  62. with open(self.output_file, 'w') as f:
  63. f.write('')
  64. def handler(self, fetch_vars):
  65. """处理每批次的推理结果"""
  66. result_dict = {}
  67. logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
  68. for var_name, var_value in fetch_vars.items():
  69. # 转换数据类型
  70. if isinstance(var_value, np.ndarray):
  71. result = var_value.tolist()
  72. else:
  73. result = var_value
  74. result_dict[var_name] = result
  75. self.current_batch.append(result_dict)
  76. self.total_samples += len(result_dict.get(list(result_dict.keys())[0], []))
  77. # 当累积足够的结果时,写入文件
  78. if len(self.current_batch) >= self.batch_size:
  79. self._write_batch()
  80. logger.info(f"Saved {self.total_samples} samples to {self.output_file}")
  81. def _write_batch(self):
  82. """将批次结果写入文件"""
  83. with open(self.output_file, 'a') as f:
  84. for result in self.current_batch:
  85. f.write(json.dumps(result) + '\n')
  86. self.current_batch = []
  87. def finish(self):
  88. logger.info("InferenceFetchHandler finish")
  89. """确保所有剩余结果都被保存"""
  90. if self.current_batch:
  91. self._write_batch()
  92. logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}")
  93. self.done_event.set()
  94. def stop(self):
  95. """停止处理"""
  96. self.terminal_event.set()
  97. self.finish()
  98. def set_var_dict(self, var_dict):
  99. """设置需要获取的变量字典"""
  100. self.var_dict = var_dict
  101. def wait(self):
  102. """等待处理完成"""
  103. self.done_event.wait()
  104. def parse_args():
  105. parser = argparse.ArgumentParser("PaddleRec train script")
  106. parser.add_argument("-o", "--opt", nargs='*', type=str)
  107. parser.add_argument(
  108. '-m',
  109. '--config_yaml',
  110. type=str,
  111. required=True,
  112. help='config file path')
  113. parser.add_argument(
  114. '-bf16',
  115. '--pure_bf16',
  116. type=ast.literal_eval,
  117. default=False,
  118. help="whether use bf16")
  119. args = parser.parse_args()
  120. args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
  121. yaml_helper = YamlHelper()
  122. config = yaml_helper.load_yaml(args.config_yaml)
  123. # modify config from command
  124. if args.opt:
  125. for parameter in args.opt:
  126. parameter = parameter.strip()
  127. key, value = parameter.split("=")
  128. if type(config.get(key)) is int:
  129. value = int(value)
  130. if type(config.get(key)) is float:
  131. value = float(value)
  132. if type(config.get(key)) is bool:
  133. value = (True if value.lower() == "true" else False)
  134. config[key] = value
  135. config["yaml_path"] = args.config_yaml
  136. config["config_abs_dir"] = args.abs_dir
  137. config["pure_bf16"] = args.pure_bf16
  138. yaml_helper.print_yaml(config)
  139. return config
  140. def bf16_to_fp32(val):
  141. return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
class Main(object):
    """Driver for distributed (parameter-server) static-graph inference.

    Builds the model network, then, depending on this node's fleet role,
    runs either the PS server loop or the worker inference loop.

    NOTE(review): several methods read the module-level ``config`` global
    rather than ``self.config``; in this script both name the same dict.
    """

    def __init__(self, config):
        # config: dict produced by parse_args() (YAML contents + overrides).
        self.metrics = {}
        self.config = config
        self.input_data = None  # feed vars, created in network()
        self.reader = None      # dataset reader, created in init_reader()
        self.exe = None         # static-graph Executor, created in run_worker()
        self.train_result_dict = {}
        self.train_result_dict["speed"] = []
        self.train_result_dict["auc"] = []
        self.model = None
        self.pure_bf16 = self.config['pure_bf16']

    def run(self):
        """Entry point: init fleet, build the network, then run by role."""
        self.init_fleet_with_gloo()
        self.network()
        if fleet.is_server():
            self.run_server()
        elif fleet.is_worker():
            self.run_worker()
            fleet.stop_worker()
            self.record_result()
        logger.info("Run Success, Exit.")

    def init_fleet_with_gloo(self, use_gloo=True):
        """Initialize fleet with a PaddleCloud role maker.

        NOTE(review): despite the parameter name, the use_gloo=True branch
        sets PADDLE_WITH_GLOO="0" and init_gloo=False, i.e. gloo is
        effectively disabled either way — confirm this is intentional.
        """
        if use_gloo:
            os.environ["PADDLE_WITH_GLOO"] = "0"
            role = role_maker.PaddleCloudRoleMaker(
                is_collective=False,
                init_gloo=False
            )
            fleet.init(role)
        else:
            fleet.init()

    def network(self):
        """Build feeds, the inference network, and the infer-mode optimizer."""
        self.model = get_model(self.config)
        self.input_data = self.model.create_feeds(is_infer=True)
        self.inference_feed_var = self.input_data
        self.init_reader()
        self.metrics = self.model.net(self.inference_feed_var, is_infer=True)
        self.inference_target_var = self.model.inference_target_var
        logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
        # create_optimizer is still invoked in infer mode — presumably so the
        # program is rewritten for the distributed PS runtime; confirm.
        self.model.create_optimizer(get_strategy(self.config), is_infer=True)

    def run_server(self):
        """Start this node as a parameter server (run_server blocks)."""
        logger.info("Run Server Begin")
        fleet.init_server(config.get("runner.warmup_model_path"))
        fleet.run_server()

    def run_worker(self):
        """Worker side: fetch the model from OSS, load it, run inference."""
        logger.info("Run Worker Begin")
        use_cuda = int(config.get("runner.use_gpu"))
        use_auc = config.get("runner.infer_use_auc", False)
        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
        self.exe = paddle.static.Executor(place)
        # Dump the main/startup programs to text files for debugging.
        with open("./{}_worker_main_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_main_program()))
        with open("./{}_worker_startup_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_startup_program()))
        self.exe.run(paddle.static.default_startup_program())
        if self.pure_bf16:
            self.model.optimizer.amp_init(self.exe.place)
        fleet.init_worker()
        # Download the packed model from OSS and unpack it to the load path.
        init_model_path = config.get("runner.infer_load_path")
        model_mode = config.get("runner.model_mode", 0)
        client = HangZhouOSSClient("art-recommend")
        oss_object_name = self.config.get("runner.oss_object_name", "dyp/model.tar.gz")
        client.get_object_to_file(oss_object_name, "model.tar.gz")
        compress.uncompress_tar("model.tar.gz", init_model_path)
        # NOTE(review): assert is stripped under ``python -O``; raise an
        # exception instead if this check must survive optimized runs.
        assert os.path.exists(init_model_path)
        #if fleet.is_first_worker():
        #fleet.load_inference_model(init_model_path, mode=int(model_mode))
        #fleet.barrier_worker()
        reader_type = self.config.get("runner.reader_type", "QueueDataset")
        epochs = int(self.config.get("runner.epochs"))  # NOTE(review): unused
        sync_mode = self.config.get("runner.sync_mode")
        opt_info = paddle.static.default_main_program()._fleet_opt
        # Register the AUC stat vars with the fleet runtime when requested.
        if use_auc is True:
            opt_info['stat_var_names'] = [
                self.model.stat_pos.name, self.model.stat_neg.name
            ]
        else:
            opt_info['stat_var_names'] = []
        if reader_type == "InmemoryDataset":
            self.reader.load_into_memory()
        fleet.load_inference_model(
            init_model_path,
            mode=int(model_mode))
        epoch_start_time = time.time()
        epoch = 0
        if sync_mode == "heter":
            self.heter_train_loop(epoch)
        elif reader_type == "QueueDataset":
            self.dataset_train_loop(epoch)
        elif reader_type == "InmemoryDataset":
            self.dataset_train_loop(epoch)
        epoch_time = time.time() - epoch_start_time
        logger.info(
            "using time {} second, ips {}/sec.".format(epoch_time, self.count_method))
        # NOTE(review): this loop blocks forever (keeps the worker process
        # alive?), which makes the release_memory() below unreachable —
        # confirm the intent.
        while True:
            time.sleep(300)
            continue;
        if reader_type == "InmemoryDataset":
            self.reader.release_memory()

    def init_reader(self):
        """Create the infer reader and file list (no-op on servers)."""
        if fleet.is_server():
            return
        self.config["runner.reader_type"] = self.config.get(
            "runner.reader_type", "QueueDataset")
        self.reader, self.file_list = get_infer_reader(self.input_data, config)
        self.example_nums = 0
        self.count_method = self.config.get("runner.example_count_method",
                                            "example")

    def dataset_train_loop(self, epoch=0):
        """Run one inference pass over the dataset via infer_from_dataset."""
        logger.info("Epoch: {}, Running Dataset Begin.".format(epoch))
        fetch_info = [
            "Epoch {} Var {}".format(epoch, var_name)
            for var_name in self.metrics
        ]
        fetch_vars = [var for _, var in self.metrics.items()]
        print_step = int(config.get("runner.print_interval"))
        debug = config.get("runner.dataset_debug", False)
        # Dumping fields implies dataset debug mode.
        if config.get("runner.need_dump"):
            debug = True
            dump_fields_path = "{}/{}".format(
                config.get("runner.dump_fields_path"), epoch)
            set_dump_config(paddle.static.default_main_program(), {
                "dump_fields_path": dump_fields_path,
                "dump_fields": config.get("runner.dump_fields")
            })
        # Set the output file path.
        # NOTE(review): output_file is computed but never used — the plain
        # FetchHandler below does not write result files; possibly
        # InferenceFetchHandler(output_file) was intended here. Confirm.
        output_dir = config.get("runner.inference_output_dir", "inference_results")
        output_file = os.path.join(output_dir, f"epoch_{epoch}_results.jsonl")
        # Create the fetch-handler instance.
        fetch_handler = FetchHandler(var_dict=self.metrics, period_secs=1)
        print(paddle.static.default_main_program()._fleet_opt)
        self.exe.infer_from_dataset(
            program=paddle.static.default_main_program(),
            dataset=self.reader,
            fetch_list=fetch_vars,
            fetch_info=fetch_info,
            print_period=print_step,
            debug=debug,
            fetch_handler=fetch_handler)

    def heter_train_loop(self, epoch):
        """Inference loop for heterogeneous ("heter") sync mode."""
        logger.info(
            "Epoch: {}, Running Begin. Check running metrics at heter_log".
            format(epoch))
        reader_type = self.config.get("runner.reader_type")
        if reader_type == "QueueDataset":
            self.exe.infer_from_dataset(
                program=paddle.static.default_main_program(),
                dataset=self.reader,
                debug=config.get("runner.dataset_debug"))
        elif reader_type == "DataLoader":
            batch_id = 0
            train_run_cost = 0.0
            total_examples = 0
            self.reader.start()
            # Drain the DataLoader until EOF, logging throughput periodically.
            while True:
                try:
                    train_start = time.time()
                    # --------------------------------------------------- #
                    self.exe.run(program=paddle.static.default_main_program())
                    # --------------------------------------------------- #
                    train_run_cost += time.time() - train_start
                    total_examples += self.config.get("runner.batch_size")
                    batch_id += 1
                    print_step = int(config.get("runner.print_period"))
                    if batch_id % print_step == 0:
                        profiler_string = ""
                        profiler_string += "avg_batch_cost: {} sec, ".format(
                            format((train_run_cost) / print_step, '.5f'))
                        profiler_string += "avg_samples: {}, ".format(
                            format(total_examples / print_step, '.5f'))
                        profiler_string += "ips: {} {}/sec ".format(
                            format(total_examples / (train_run_cost), '.5f'),
                            self.count_method)
                        logger.info("Epoch: {}, Batch: {}, {}".format(
                            epoch, batch_id, profiler_string))
                        train_run_cost = 0.0
                        total_examples = 0
                except paddle.core.EOFException:
                    # Reader exhausted: reset it and finish the epoch.
                    self.reader.reset()
                    break

    def record_result(self):
        """Persist train_result_dict to a text file.

        NOTE(review): the speed/auc lists are never appended to anywhere in
        this script, so the recorded dict holds empty lists.
        """
        logger.info("train_result_dict: {}".format(self.train_result_dict))
        with open("./train_result_dict.txt", 'w+') as f:
            f.write(str(self.train_result_dict))
  329. if __name__ == "__main__":
  330. paddle.enable_static()
  331. config = parse_args()
  332. os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
  333. benchmark_main = Main(config)
  334. benchmark_main.run()