static_ps_infer_v2.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import print_function
  15. import os
  16. os.environ['FLAGS_enable_pir_api'] = '0'
  17. from utils.static_ps.reader_helper import get_infer_reader
  18. from utils.static_ps.program_helper import get_model, get_strategy, set_dump_config
  19. from utils.static_ps.metric_helper import set_zero, get_global_auc
  20. from utils.static_ps.common_ps import YamlHelper, is_distributed_env
  21. import argparse
  22. import time
  23. import sys
  24. import paddle.distributed.fleet as fleet
  25. import paddle.distributed.fleet.base.role_maker as role_maker
  26. import paddle
  27. import threading
  28. import warnings
  29. import logging
  30. import ast
  31. import numpy as np
  32. import struct
  33. from utils.utils_single import auc
  34. from utils.oss_client import HangZhouOSSClient
  35. import utils.compress as compress
  36. __dir__ = os.path.dirname(os.path.abspath(__file__))
  37. sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
  38. root_loger = logging.getLogger()
  39. for handler in root_loger.handlers[:]:
  40. root_loger.removeHandler(handler)
  41. logging.basicConfig(
  42. format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
  43. logger = logging.getLogger(__name__)
  44. import json
  45. class InferenceFetchHandler(object):
  46. def __init__(self, output_file, batch_size=1000):
  47. self.output_file = output_file
  48. self.batch_size = batch_size
  49. self.current_batch = []
  50. self.total_samples = 0
  51. # 添加 Paddle 需要的属性
  52. self.period_secs = 60 # 设置默认的周期时间(秒)
  53. self.done_event = threading.Event() # 添加完成事件
  54. self.terminal_event = threading.Event() # 添加终止事件
  55. # 创建输出目录(如果不存在)
  56. output_dir = os.path.dirname(output_file)
  57. if not os.path.exists(output_dir):
  58. os.makedirs(output_dir)
  59. self.var_dict = {} # 用于存储需要获取的变量
  60. # 创建或清空输出文件
  61. with open(self.output_file, 'w') as f:
  62. f.write('')
  63. def handler(self, fetch_vars):
  64. """处理每批次的推理结果"""
  65. result_dict = {}
  66. for var_name, var_value in fetch_vars.items():
  67. # 转换数据类型
  68. if isinstance(var_value, np.ndarray):
  69. result = var_value.tolist()
  70. else:
  71. result = var_value
  72. result_dict[var_name] = result
  73. self.current_batch.append(result_dict)
  74. self.total_samples += len(result_dict.get(list(result_dict.keys())[0], []))
  75. # 当累积足够的结果时,写入文件
  76. if len(self.current_batch) >= self.batch_size:
  77. self._write_batch()
  78. logger.info(f"Saved {self.total_samples} samples to {self.output_file}")
  79. def _write_batch(self):
  80. """将批次结果写入文件"""
  81. with open(self.output_file, 'a') as f:
  82. for result in self.current_batch:
  83. f.write(json.dumps(result) + '\n')
  84. self.current_batch = []
  85. def finish(self):
  86. """确保所有剩余结果都被保存"""
  87. if self.current_batch:
  88. self._write_batch()
  89. logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}")
  90. self.done_event.set()
  91. def stop(self):
  92. """停止处理"""
  93. self.terminal_event.set()
  94. self.finish()
  95. def set_var_dict(self, var_dict):
  96. """设置需要获取的变量字典"""
  97. self.var_dict = var_dict
  98. def wait(self):
  99. """等待处理完成"""
  100. self.done_event.wait()
  101. def parse_args():
  102. parser = argparse.ArgumentParser("PaddleRec train script")
  103. parser.add_argument("-o", "--opt", nargs='*', type=str)
  104. parser.add_argument(
  105. '-m',
  106. '--config_yaml',
  107. type=str,
  108. required=True,
  109. help='config file path')
  110. parser.add_argument(
  111. '-bf16',
  112. '--pure_bf16',
  113. type=ast.literal_eval,
  114. default=False,
  115. help="whether use bf16")
  116. args = parser.parse_args()
  117. args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
  118. yaml_helper = YamlHelper()
  119. config = yaml_helper.load_yaml(args.config_yaml)
  120. # modify config from command
  121. if args.opt:
  122. for parameter in args.opt:
  123. parameter = parameter.strip()
  124. key, value = parameter.split("=")
  125. if type(config.get(key)) is int:
  126. value = int(value)
  127. if type(config.get(key)) is float:
  128. value = float(value)
  129. if type(config.get(key)) is bool:
  130. value = (True if value.lower() == "true" else False)
  131. config[key] = value
  132. config["yaml_path"] = args.config_yaml
  133. config["config_abs_dir"] = args.abs_dir
  134. config["pure_bf16"] = args.pure_bf16
  135. yaml_helper.print_yaml(config)
  136. return config
  137. def bf16_to_fp32(val):
  138. return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
  139. class Main(object):
  140. def __init__(self, config):
  141. self.metrics = {}
  142. self.config = config
  143. self.input_data = None
  144. self.reader = None
  145. self.exe = None
  146. self.train_result_dict = {}
  147. self.train_result_dict["speed"] = []
  148. self.train_result_dict["auc"] = []
  149. self.model = None
  150. self.pure_bf16 = self.config['pure_bf16']
  151. def run(self):
  152. self.init_fleet_with_gloo()
  153. self.network()
  154. if fleet.is_server():
  155. self.run_server()
  156. elif fleet.is_worker():
  157. self.run_worker()
  158. fleet.stop_worker()
  159. self.record_result()
  160. logger.info("Run Success, Exit.")
  161. def init_fleet_with_gloo(self,use_gloo=True):
  162. if use_gloo:
  163. os.environ["PADDLE_WITH_GLOO"] = "0"
  164. role = role_maker.PaddleCloudRoleMaker(
  165. is_collective=False,
  166. init_gloo=False
  167. )
  168. fleet.init(role)
  169. else:
  170. fleet.init()
  171. def network(self):
  172. self.model = get_model(self.config)
  173. self.input_data = self.model.create_feeds(is_infer=True)
  174. self.inference_feed_var = self.input_data
  175. self.init_reader()
  176. self.metrics = self.model.net(self.inference_feed_var,is_infer=True)
  177. self.inference_target_var = self.model.inference_target_var
  178. logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
  179. self.model.create_optimizer(get_strategy(self.config),is_infer=True)
  180. def run_server(self):
  181. logger.info("Run Server Begin")
  182. fleet.init_server(config.get("runner.warmup_model_path"))
  183. fleet.run_server()
  184. def run_worker(self):
  185. logger.info("Run Worker Begin")
  186. use_cuda = int(config.get("runner.use_gpu"))
  187. use_auc = config.get("runner.infer_use_auc", False)
  188. place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
  189. self.exe = paddle.static.Executor(place)
  190. with open("./{}_worker_main_program.prototxt".format(
  191. fleet.worker_index()), 'w+') as f:
  192. f.write(str(paddle.static.default_main_program()))
  193. with open("./{}_worker_startup_program.prototxt".format(
  194. fleet.worker_index()), 'w+') as f:
  195. f.write(str(paddle.static.default_startup_program()))
  196. self.exe.run(paddle.static.default_startup_program())
  197. if self.pure_bf16:
  198. self.model.optimizer.amp_init(self.exe.place)
  199. fleet.init_worker()
  200. init_model_path = config.get("runner.infer_load_path")
  201. model_mode = config.get("runner.model_mode", 0)
  202. client = HangZhouOSSClient("art-recommend")
  203. client.get_object_to_file("lqc/64.tar.gz", "64.tar.gz")
  204. compress.uncompress_tar("64.tar.gz", init_model_path)
  205. assert os.path.exists(init_model_path)
  206. #if fleet.is_first_worker():
  207. #fleet.load_inference_model(init_model_path, mode=int(model_mode))
  208. #fleet.barrier_worker()
  209. save_model_path = self.config.get("runner.model_save_path")
  210. if save_model_path and (not os.path.exists(save_model_path)):
  211. os.makedirs(save_model_path)
  212. reader_type = self.config.get("runner.reader_type", "QueueDataset")
  213. epochs = int(self.config.get("runner.epochs"))
  214. sync_mode = self.config.get("runner.sync_mode")
  215. opt_info = paddle.static.default_main_program()._fleet_opt
  216. if use_auc is True:
  217. opt_info['stat_var_names'] = [
  218. self.model.stat_pos.name, self.model.stat_neg.name
  219. ]
  220. else:
  221. opt_info['stat_var_names'] = []
  222. if reader_type == "InmemoryDataset":
  223. self.reader.load_into_memory()
  224. fleet.load_inference_model(
  225. init_model_path,
  226. mode=int(model_mode))
  227. epoch_start_time = time.time()
  228. epoch = 0
  229. if sync_mode == "heter":
  230. self.heter_train_loop(epoch)
  231. elif reader_type == "QueueDataset":
  232. self.dataset_train_loop(epoch)
  233. elif reader_type == "InmemoryDataset":
  234. self.dataset_train_loop(epoch)
  235. epoch_time = time.time() - epoch_start_time
  236. logger.info(
  237. "using time {} second, ips {}/sec.".format(epoch_time, self.count_method))
  238. if reader_type == "InmemoryDataset":
  239. self.reader.release_memory()
  240. def init_reader(self):
  241. if fleet.is_server():
  242. return
  243. self.config["runner.reader_type"] = self.config.get(
  244. "runner.reader_type", "QueueDataset")
  245. self.reader, self.file_list = get_infer_reader(self.input_data, config)
  246. self.example_nums = 0
  247. self.count_method = self.config.get("runner.example_count_method",
  248. "example")
  249. def dataset_train_loop(self, epoch=0):
  250. logger.info("Epoch: {}, Running Dataset Begin.".format(epoch))
  251. fetch_info = [
  252. "Epoch {} Var {}".format(epoch, var_name)
  253. for var_name in self.metrics
  254. ]
  255. fetch_vars = [var for _, var in self.metrics.items()]
  256. print_step = int(config.get("runner.print_interval"))
  257. debug = config.get("runner.dataset_debug", False)
  258. if config.get("runner.need_dump"):
  259. debug = True
  260. dump_fields_path = "{}/{}".format(
  261. config.get("runner.dump_fields_path"), epoch)
  262. set_dump_config(paddle.static.default_main_program(), {
  263. "dump_fields_path": dump_fields_path,
  264. "dump_fields": config.get("runner.dump_fields")
  265. })
  266. # 设置输出文件路径
  267. output_dir = config.get("runner.inference_output_dir", "inference_results")
  268. output_file = os.path.join(output_dir, f"epoch_{epoch}_results.jsonl")
  269. # 创建处理器实例
  270. fetch_handler = InferenceFetchHandler(output_file)
  271. print(paddle.static.default_main_program()._fleet_opt)
  272. results = self.exe.infer_from_dataset(
  273. program=paddle.static.default_main_program(),
  274. dataset=self.reader,
  275. fetch_list=fetch_vars,
  276. fetch_info=fetch_info,
  277. print_period=print_step,
  278. debug=debug,
  279. fetch_handler=fetch_handler)
  280. print("results {}".format(results))
  281. fetch_handler.finish()
  282. def heter_train_loop(self, epoch):
  283. logger.info(
  284. "Epoch: {}, Running Begin. Check running metrics at heter_log".
  285. format(epoch))
  286. reader_type = self.config.get("runner.reader_type")
  287. if reader_type == "QueueDataset":
  288. self.exe.infer_from_dataset(
  289. program=paddle.static.default_main_program(),
  290. dataset=self.reader,
  291. debug=config.get("runner.dataset_debug"))
  292. elif reader_type == "DataLoader":
  293. batch_id = 0
  294. train_run_cost = 0.0
  295. total_examples = 0
  296. self.reader.start()
  297. while True:
  298. try:
  299. train_start = time.time()
  300. # --------------------------------------------------- #
  301. self.exe.run(program=paddle.static.default_main_program())
  302. # --------------------------------------------------- #
  303. train_run_cost += time.time() - train_start
  304. total_examples += self.config.get("runner.batch_size")
  305. batch_id += 1
  306. print_step = int(config.get("runner.print_period"))
  307. if batch_id % print_step == 0:
  308. profiler_string = ""
  309. profiler_string += "avg_batch_cost: {} sec, ".format(
  310. format((train_run_cost) / print_step, '.5f'))
  311. profiler_string += "avg_samples: {}, ".format(
  312. format(total_examples / print_step, '.5f'))
  313. profiler_string += "ips: {} {}/sec ".format(
  314. format(total_examples / (train_run_cost), '.5f'),
  315. self.count_method)
  316. logger.info("Epoch: {}, Batch: {}, {}".format(
  317. epoch, batch_id, profiler_string))
  318. train_run_cost = 0.0
  319. total_examples = 0
  320. except paddle.core.EOFException:
  321. self.reader.reset()
  322. break
  323. def record_result(self):
  324. logger.info("train_result_dict: {}".format(self.train_result_dict))
  325. with open("./train_result_dict.txt", 'w+') as f:
  326. f.write(str(self.train_result_dict))
  327. if __name__ == "__main__":
  328. paddle.enable_static()
  329. config = parse_args()
  330. os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
  331. benchmark_main = Main(config)
  332. benchmark_main.run()