# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
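
"""Static-graph parameter-server inference entry point.

Downloads a model archive from OSS, loads it with paddle.distributed.fleet,
runs inference over a dataset reader, and streams the fetched results to
JSON-lines files.
"""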
from __future__ import print_function
import os
os.environ['FLAGS_enable_pir_api'] = '0'
from utils.static_ps.reader_helper_hdfs import get_infer_reader
from utils.static_ps.program_helper import get_model, get_strategy, set_dump_config
from utils.static_ps.metric_helper import set_zero, get_global_auc
from utils.static_ps.common_ps import YamlHelper, is_distributed_env
import argparse
import json
import time
import sys
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
import paddle
from paddle.base.executor import FetchHandler
import queue
import threading
import warnings
import logging
import ast
import numpy as np
import struct
from utils.utils_single import auc
from utils.oss_client import HangZhouOSSClient
import utils.compress as compress

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

# Remove any pre-existing root handlers so basicConfig takes effect.
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
    root_logger.removeHandler(handler)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)


class InferenceFetchHandler(FetchHandler):
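    """Collect variables fetched during inference and write them to a
    JSON-lines file from a background daemon thread.

    The executor calls handler() with the fetched variables roughly every
    period_secs; results are queued and appended to output_file in batches
    of batch_size.
    """
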
    def __init__(self, var_dict, output_file, batch_size=1000):
        super().__init__(var_dict=var_dict, period_secs=1)
        self.output_file = output_file
        self.batch_size = batch_size
        self.result_queue = queue.Queue()
        self._pending = []  # results buffered by the writer thread, not yet on disk
        self.writer_thread = threading.Thread(target=self._writer)
        self.writer_thread.daemon = True  # daemon thread: exits with the main process
        self.writer_thread.start()
        # Create the output directory if it does not exist.
        output_dir = os.path.dirname(output_file)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # Create or truncate the output file.
        with open(self.output_file, 'w') as f:
            f.write('')
    def handler(self, fetch_vars):
        """Handle one batch of fetched inference results."""
        result_dict = {}
        for key in fetch_vars:
            # Unwrap numpy arrays into plain Python values so the results
            # are JSON-serializable.
            if type(fetch_vars[key]) is np.ndarray:
                result = fetch_vars[key][0].tolist()
            else:
                result = fetch_vars[key]
            result_dict[key] = result
        self.result_queue.put(result_dict)  # hand off to the writer thread
    def _writer(self):
        # Background loop: drain the queue and append to the file in batches.
        while True:
            try:
                result_dict = self.result_queue.get(timeout=1)  # wait up to 1s
                logger.info("write vector {}".format(result_dict))
                self._pending.append(result_dict)
                if len(self._pending) >= self.batch_size:
                    logger.info("write vector")
                    self._write_batch(self._pending)
                    self._pending = []
            except queue.Empty:
                pass
    def _write_batch(self, batch):
        with open(self.output_file, 'a') as f:
            for result in batch:
                f.write(json.dumps(result) + '\n')
    def flush(self):
        """Write any buffered or still-queued results to the file."""
        # queue.Queue.join() would deadlock here because _writer never calls
        # task_done(), so drain the queue directly instead.
        remaining = self._pending
        self._pending = []
        while not self.result_queue.empty():
            remaining.append(self.result_queue.get_nowait())
        self._write_batch(remaining)


def parse_args():
    parser = argparse.ArgumentParser("PaddleRec train script")
    parser.add_argument("-o", "--opt", nargs='*', type=str)
    parser.add_argument(
        '-m',
        '--config_yaml',
        type=str,
        required=True,
        help='config file path')
    parser.add_argument(
        '-bf16',
        '--pure_bf16',
        type=ast.literal_eval,
        default=False,
        help="whether to use bf16")
    args = parser.parse_args()
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    yaml_helper = YamlHelper()
    config = yaml_helper.load_yaml(args.config_yaml)
    # Override config entries from the command line ("-o key=value"),
    # coercing each value to the type of the existing entry.
    if args.opt:
        for parameter in args.opt:
            parameter = parameter.strip()
            key, value = parameter.split("=")
            if type(config.get(key)) is int:
                value = int(value)
            if type(config.get(key)) is float:
                value = float(value)
            if type(config.get(key)) is bool:
                value = (value.lower() == "true")
            config[key] = value
    config["yaml_path"] = args.config_yaml
    config["config_abs_dir"] = args.abs_dir
    config["pure_bf16"] = args.pure_bf16
    yaml_helper.print_yaml(config)
    return config
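
# Example invocation (values are illustrative):
#   python static_ps_infer_v2.py -m config.yaml -o runner.use_gpu=0 runner.epochs=1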


def bf16_to_fp32(val):
    # A bf16 value holds the upper 16 bits of an fp32: shift left and
    # reinterpret the 32-bit pattern as a float.
    return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
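
# For example, fp32 1.0 has bit pattern 0x3F800000, so its bf16 encoding is
# the upper half 0x3F80, and bf16_to_fp32(0x3F80) returns 1.0.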


class Main(object):
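    """Driver object: initializes fleet, builds the inference network, and
    dispatches to the server or worker role."""
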
    def __init__(self, config):
        self.metrics = {}
        self.config = config
        self.input_data = None
        self.reader = None
        self.exe = None
        self.train_result_dict = {}
        self.train_result_dict["speed"] = []
        self.train_result_dict["auc"] = []
        self.model = None
        self.pure_bf16 = self.config['pure_bf16']
    def run(self):
        self.init_fleet_with_gloo()
        self.network()
        if fleet.is_server():
            self.run_server()
        elif fleet.is_worker():
            self.run_worker()
            fleet.stop_worker()
            self.record_result()
        logger.info("Run Success, Exit.")
    def init_fleet_with_gloo(self, use_gloo=True):
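        """Initialize fleet with a PaddleCloudRoleMaker.

        Note: despite the parameter name, gloo is disabled on this path
        (PADDLE_WITH_GLOO is set to "0" and init_gloo=False).
        """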
        if use_gloo:
            os.environ["PADDLE_WITH_GLOO"] = "0"
            role = role_maker.PaddleCloudRoleMaker(
                is_collective=False, init_gloo=False)
            fleet.init(role)
        else:
            fleet.init()
    def network(self):
        self.model = get_model(self.config)
        self.input_data = self.model.create_feeds(is_infer=True)
        self.inference_feed_var = self.input_data
        self.init_reader()
        self.metrics = self.model.net(self.inference_feed_var, is_infer=True)
        self.inference_target_var = self.model.inference_target_var
        logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
        self.model.create_optimizer(get_strategy(self.config), is_infer=True)
    def run_server(self):
        logger.info("Run Server Begin")
        fleet.init_server(config.get("runner.warmup_model_path"))
        fleet.run_server()
    def run_worker(self):
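        """Worker role: dump the worker programs for debugging, pull the
        model archive from OSS, load it as an inference model, and run one
        inference pass over the dataset."""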
        logger.info("Run Worker Begin")
        use_cuda = int(config.get("runner.use_gpu"))
        use_auc = config.get("runner.infer_use_auc", False)
        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
        self.exe = paddle.static.Executor(place)
        with open("./{}_worker_main_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_main_program()))
        with open("./{}_worker_startup_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_startup_program()))
        self.exe.run(paddle.static.default_startup_program())
        if self.pure_bf16:
            self.model.optimizer.amp_init(self.exe.place)
        fleet.init_worker()

        # Fetch the model archive from OSS and unpack it into the load path.
        init_model_path = config.get("runner.infer_load_path")
        model_mode = config.get("runner.model_mode", 0)
        client = HangZhouOSSClient("art-recommend")
        oss_object_name = self.config.get("runner.oss_object_name",
                                          "dyp/model.tar.gz")
        client.get_object_to_file(oss_object_name, "model.tar.gz")
        compress.uncompress_tar("model.tar.gz", init_model_path)
        assert os.path.exists(init_model_path)
        # if fleet.is_first_worker():
        #     fleet.load_inference_model(init_model_path, mode=int(model_mode))
        # fleet.barrier_worker()

        reader_type = self.config.get("runner.reader_type", "QueueDataset")
        epochs = int(self.config.get("runner.epochs"))
        sync_mode = self.config.get("runner.sync_mode")
        opt_info = paddle.static.default_main_program()._fleet_opt
        if use_auc is True:
            opt_info['stat_var_names'] = [
                self.model.stat_pos.name, self.model.stat_neg.name
            ]
        else:
            opt_info['stat_var_names'] = []

        if reader_type == "InmemoryDataset":
            self.reader.load_into_memory()

        fleet.load_inference_model(init_model_path, mode=int(model_mode))

        epoch_start_time = time.time()
        epoch = 0
        if sync_mode == "heter":
            self.heter_train_loop(epoch)
        elif reader_type in ("QueueDataset", "InmemoryDataset"):
            self.dataset_train_loop(epoch)
        epoch_time = time.time() - epoch_start_time
        logger.info("using time {} second, ips {}/sec.".format(
            epoch_time, self.count_method))

        if reader_type == "InmemoryDataset":
            self.reader.release_memory()
        # Keep the worker process alive indefinitely after inference finishes.
        # (In the original flow this loop preceded release_memory, which made
        # that call unreachable.)
        while True:
            time.sleep(300)
    def init_reader(self):
        if fleet.is_server():
            return
        self.config["runner.reader_type"] = self.config.get(
            "runner.reader_type", "QueueDataset")
        self.reader, self.file_list = get_infer_reader(self.input_data, config)
        self.example_nums = 0
        self.count_method = self.config.get("runner.example_count_method",
                                            "example")
    def dataset_train_loop(self, epoch=0):
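        """Run one inference pass via infer_from_dataset, streaming fetched
        metrics to a per-epoch JSON-lines file through InferenceFetchHandler."""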
        logger.info("Epoch: {}, Running Dataset Begin.".format(epoch))
        fetch_info = [
            "Epoch {} Var {}".format(epoch, var_name)
            for var_name in self.metrics
        ]
        fetch_vars = [var for _, var in self.metrics.items()]
        print_step = int(config.get("runner.print_interval"))
        debug = config.get("runner.dataset_debug", False)
        if config.get("runner.need_dump"):
            debug = True
            dump_fields_path = "{}/{}".format(
                config.get("runner.dump_fields_path"), epoch)
            set_dump_config(paddle.static.default_main_program(), {
                "dump_fields_path": dump_fields_path,
                "dump_fields": config.get("runner.dump_fields")
            })
        # Per-epoch output file for inference results.
        output_dir = config.get("runner.inference_output_dir",
                                "inference_results")
        output_file = os.path.join(output_dir, f"epoch_{epoch}_results.json")
        # Handler that collects fetched variables and writes them to disk.
        fetch_handler = InferenceFetchHandler(
            var_dict=self.metrics, output_file=output_file)
        print(paddle.static.default_main_program()._fleet_opt)
        self.exe.infer_from_dataset(
            program=paddle.static.default_main_program(),
            dataset=self.reader,
            fetch_list=fetch_vars,
            fetch_info=fetch_info,
            print_period=print_step,
            debug=debug,
            fetch_handler=fetch_handler)
        fetch_handler.flush()
    def heter_train_loop(self, epoch):
        logger.info(
            "Epoch: {}, Running Begin. Check running metrics at heter_log".
            format(epoch))
        reader_type = self.config.get("runner.reader_type")
        if reader_type == "QueueDataset":
            self.exe.infer_from_dataset(
                program=paddle.static.default_main_program(),
                dataset=self.reader,
                debug=config.get("runner.dataset_debug"))
        elif reader_type == "DataLoader":
            batch_id = 0
            train_run_cost = 0.0
            total_examples = 0
            self.reader.start()
            while True:
                try:
                    train_start = time.time()
                    # --------------------------------------------------- #
                    self.exe.run(program=paddle.static.default_main_program())
                    # --------------------------------------------------- #
                    train_run_cost += time.time() - train_start
                    total_examples += self.config.get("runner.batch_size")
                    batch_id += 1
                    print_step = int(config.get("runner.print_period"))
                    if batch_id % print_step == 0:
                        profiler_string = ""
                        profiler_string += "avg_batch_cost: {} sec, ".format(
                            format(train_run_cost / print_step, '.5f'))
                        profiler_string += "avg_samples: {}, ".format(
                            format(total_examples / print_step, '.5f'))
                        profiler_string += "ips: {} {}/sec ".format(
                            format(total_examples / train_run_cost, '.5f'),
                            self.count_method)
                        logger.info("Epoch: {}, Batch: {}, {}".format(
                            epoch, batch_id, profiler_string))
                        train_run_cost = 0.0
                        total_examples = 0
                except paddle.core.EOFException:
                    self.reader.reset()
                    break
    def record_result(self):
        logger.info("train_result_dict: {}".format(self.train_result_dict))
        with open("./train_result_dict.txt", 'w+') as f:
            f.write(str(self.train_result_dict))


if __name__ == "__main__":
    paddle.enable_static()
    config = parse_args()
    os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
    benchmark_main = Main(config)
    benchmark_main.run()
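
# A typical parameter-server launch (illustrative; the process counts and
# config path are assumptions, not taken from this repo):
#   fleetrun --server_num=1 --worker_num=2 static_ps_infer_v2.py -m config.yaml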