# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
# FLAGS_enable_pir_api must be set before paddle is imported so the legacy
# (non-PIR) static-graph API is used.
os.environ['FLAGS_enable_pir_api'] = '0'
import sys

# Make the parent directory importable *before* pulling in the utils package.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

import argparse
import ast
import json
import logging
import struct
import threading
import time
import warnings

import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.base.executor import FetchHandler

from utils.static_ps.reader_helper_hdfs import get_infer_reader
from utils.static_ps.program_helper import get_model, get_strategy, set_dump_config
from utils.static_ps.metric_helper import set_zero, get_global_auc
from utils.static_ps.common_ps import YamlHelper, is_distributed_env
from utils.utils_single import auc
from utils.oss_client import HangZhouOSSClient
import utils.compress as compress

# Reset any handlers installed by the imports above so basicConfig takes effect.
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
    root_logger.removeHandler(handler)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)


class InferenceFetchHandler(FetchHandler):
    """Collects fetched variables during inference and appends them to a JSONL file."""

    def __init__(self, var_dict, output_file, batch_size=1000):
        super().__init__(var_dict=var_dict, period_secs=1)
        self.output_file = output_file
        # Flush to disk once this many handler callbacks have accumulated.
        self.batch_size = batch_size
        self.current_batch = []
        self.total_samples = 0

        # Create the output directory if it does not exist.
        output_dir = os.path.dirname(output_file)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        # Create or truncate the output file.
        with open(self.output_file, 'w') as f:
            f.write('')

    def handler(self, fetch_vars):
        """Process the fetched variables of one fetch period."""
        result_dict = {}
        logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
        for var_name, var_value in fetch_vars.items():
            # Convert numpy arrays to plain lists so they are JSON-serializable.
            if isinstance(var_value, np.ndarray):
                result = var_value.tolist()
            else:
                result = var_value
            result_dict[var_name] = result

        self.current_batch.append(result_dict)
        if result_dict:
            first_value = next(iter(result_dict.values()))
            self.total_samples += len(first_value) if hasattr(
                first_value, '__len__') else 1

        # Write to file once enough results have accumulated.
        if len(self.current_batch) >= self.batch_size:
            self._write_batch()
            logger.info(f"Saved {self.total_samples} samples to {self.output_file}")

    def _write_batch(self):
        """Append the buffered results to the output file, one JSON object per line."""
        with open(self.output_file, 'a') as f:
            for result in self.current_batch:
                f.write(json.dumps(result) + '\n')
        self.current_batch = []

    def finish(self):
        """Make sure all remaining results are saved before shutting down."""
        logger.info("InferenceFetchHandler finish")
        if self.current_batch:
            self._write_batch()
        logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}")
        # done_event is attached externally; guard in case it is absent.
        done_event = getattr(self, 'done_event', None)
        if done_event is not None:
            done_event.set()
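
# Usage sketch for InferenceFetchHandler (hypothetical names: `predict_var`,
# `exe`, `main_prog`, `dataset`, and the file path are placeholders). The
# executor polls the handler roughly every `period_secs`, so results arrive
# grouped per fetch period rather than strictly per mini-batch:
#
#   handler = InferenceFetchHandler(var_dict={"predict": predict_var},
#                                   output_file="infer_out/results.jsonl")
#   exe.infer_from_dataset(program=main_prog, dataset=dataset,
#                          fetch_handler=handler)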


def parse_args():
    parser = argparse.ArgumentParser("PaddleRec inference script")
    parser.add_argument("-o", "--opt", nargs='*', type=str)
    parser.add_argument(
        '-m',
        '--config_yaml',
        type=str,
        required=True,
        help='config file path')
    parser.add_argument(
        '-bf16',
        '--pure_bf16',
        type=ast.literal_eval,
        default=False,
        help="whether to use bf16")
    args = parser.parse_args()
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    yaml_helper = YamlHelper()
    config = yaml_helper.load_yaml(args.config_yaml)
    # Override config entries from the command line (-o key=value ...),
    # coercing each value to the type already present in the YAML.
    if args.opt:
        for parameter in args.opt:
            parameter = parameter.strip()
            key, value = parameter.split("=", 1)
            if type(config.get(key)) is int:
                value = int(value)
            if type(config.get(key)) is float:
                value = float(value)
            if type(config.get(key)) is bool:
                value = value.lower() == "true"
            config[key] = value
    config["yaml_path"] = args.config_yaml
    config["config_abs_dir"] = args.abs_dir
    config["pure_bf16"] = args.pure_bf16
    yaml_helper.print_yaml(config)
    return config
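
# Example invocation (file names and override values are illustrative):
#   python infer.py -m models/rank/dnn/config.yaml \
#       -o runner.infer_load_path=./output_model runner.use_gpu=False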


def bf16_to_fp32(val):
    # Reinterpret a uint16 bf16 bit pattern as the high half of an IEEE-754
    # float32 (the low 16 mantissa bits are zero-filled).
    return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
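
# Worked example: the bf16 pattern 0x3F80 shifted left 16 bits is 0x3F800000,
# the float32 bit pattern for 1.0, so bf16_to_fp32(0x3F80) == 1.0.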


class Main(object):
    def __init__(self, config):
        self.metrics = {}
        self.config = config
        self.input_data = None
        self.reader = None
        self.exe = None
        self.train_result_dict = {}
        self.train_result_dict["speed"] = []
        self.train_result_dict["auc"] = []
        self.model = None
        self.pure_bf16 = self.config['pure_bf16']

    def run(self):
        self.init_fleet_with_gloo()
        self.network()
        if fleet.is_server():
            self.run_server()
        elif fleet.is_worker():
            self.run_worker()
            fleet.stop_worker()
        self.record_result()
        logger.info("Run Success, Exit.")

    def init_fleet_with_gloo(self, use_gloo=True):
        if use_gloo:
            # Note: despite the parameter name, this branch runs with gloo
            # disabled (PADDLE_WITH_GLOO="0" and init_gloo=False), so gloo
            # barriers are unavailable.
            os.environ["PADDLE_WITH_GLOO"] = "0"
            role = role_maker.PaddleCloudRoleMaker(
                is_collective=False, init_gloo=False)
            fleet.init(role)
        else:
            fleet.init()

    def network(self):
        self.model = get_model(self.config)
        self.input_data = self.model.create_feeds(is_infer=True)
        self.inference_feed_var = self.input_data
        self.init_reader()
        self.metrics = self.model.net(self.inference_feed_var, is_infer=True)
        self.inference_target_var = self.model.inference_target_var
        logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
        self.model.create_optimizer(get_strategy(self.config), is_infer=True)

    def run_server(self):
        logger.info("Run Server Begin")
        fleet.init_server(config.get("runner.warmup_model_path"))
        fleet.run_server()

    def run_worker(self):
        logger.info("Run Worker Begin")
        use_cuda = int(config.get("runner.use_gpu"))
        use_auc = config.get("runner.infer_use_auc", False)
        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
        self.exe = paddle.static.Executor(place)
        with open("./{}_worker_main_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_main_program()))
        with open("./{}_worker_startup_program.prototxt".format(
                fleet.worker_index()), 'w+') as f:
            f.write(str(paddle.static.default_startup_program()))
        self.exe.run(paddle.static.default_startup_program())
        if self.pure_bf16:
            self.model.optimizer.amp_init(self.exe.place)
        fleet.init_worker()

        # Download and unpack the model archive into infer_load_path before
        # loading it.
        init_model_path = config.get("runner.infer_load_path")
        model_mode = config.get("runner.model_mode", 0)
        client = HangZhouOSSClient("art-recommend")
        oss_object_name = self.config.get("runner.oss_object_name",
                                          "dyp/model.tar.gz")
        client.get_object_to_file(oss_object_name, "model.tar.gz")
        compress.uncompress_tar("model.tar.gz", init_model_path)
        assert os.path.exists(init_model_path)

        #if fleet.is_first_worker():
        #fleet.load_inference_model(init_model_path, mode=int(model_mode))
        #fleet.barrier_worker()

        reader_type = self.config.get("runner.reader_type", "QueueDataset")
        epochs = int(self.config.get("runner.epochs"))
        sync_mode = self.config.get("runner.sync_mode")
        opt_info = paddle.static.default_main_program()._fleet_opt
        if use_auc:
            opt_info['stat_var_names'] = [
                self.model.stat_pos.name, self.model.stat_neg.name
            ]
        else:
            opt_info['stat_var_names'] = []

        if reader_type == "InmemoryDataset":
            self.reader.load_into_memory()
        fleet.load_inference_model(init_model_path, mode=int(model_mode))

        epoch_start_time = time.time()
        epoch = 0
        if sync_mode == "heter":
            self.heter_train_loop(epoch)
        elif reader_type in ("QueueDataset", "InmemoryDataset"):
            self.dataset_train_loop(epoch)
        epoch_time = time.time() - epoch_start_time
        logger.info("Inference finished, using time {:.2f} seconds "
                    "(count method: {}).".format(epoch_time, self.count_method))

        if reader_type == "InmemoryDataset":
            self.reader.release_memory()
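
    # Example runner section consumed by run_worker (values are hypothetical;
    # the keys are the ones read above):
    #   runner:
    #     use_gpu: 0
    #     infer_load_path: "./infer_model"
    #     model_mode: 0
    #     oss_object_name: "dyp/model.tar.gz"
    #     reader_type: "QueueDataset"
    #     sync_mode: "async"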

    def init_reader(self):
        if fleet.is_server():
            return
        self.config["runner.reader_type"] = self.config.get(
            "runner.reader_type", "QueueDataset")
        self.reader, self.file_list = get_infer_reader(self.input_data, config)
        self.example_nums = 0
        self.count_method = self.config.get("runner.example_count_method",
                                            "example")

    def dataset_train_loop(self, epoch=0):
        logger.info("Epoch: {}, Running Dataset Begin.".format(epoch))
        fetch_info = [
            "Epoch {} Var {}".format(epoch, var_name)
            for var_name in self.metrics
        ]
        fetch_vars = [var for _, var in self.metrics.items()]
        print_step = int(config.get("runner.print_interval"))
        debug = config.get("runner.dataset_debug", False)
        if config.get("runner.need_dump"):
            debug = True
            dump_fields_path = "{}/{}".format(
                config.get("runner.dump_fields_path"), epoch)
            set_dump_config(paddle.static.default_main_program(), {
                "dump_fields_path": dump_fields_path,
                "dump_fields": config.get("runner.dump_fields")
            })
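        # Example YAML snippet that enables dumping (values are hypothetical;
        # the keys are the ones read above):
        #   runner:
        #     need_dump: True
        #     dump_fields_path: "./dump"
        #     dump_fields: ["ctr.tmp_0"]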

        # Build the JSONL output path for this epoch's inference results.
        output_dir = config.get("runner.inference_output_dir",
                                "inference_results")
        output_file = os.path.join(output_dir,
                                   "epoch_{}_results.jsonl".format(epoch))

        # Create the handler that streams fetched results to disk.
        fetch_handler = InferenceFetchHandler(
            var_dict=self.metrics, output_file=output_file)
        logger.info(paddle.static.default_main_program()._fleet_opt)
        self.exe.infer_from_dataset(
            program=paddle.static.default_main_program(),
            dataset=self.reader,
            fetch_list=fetch_vars,
            fetch_info=fetch_info,
            print_period=print_step,
            debug=debug,
            fetch_handler=fetch_handler)

    def heter_train_loop(self, epoch):
        logger.info(
            "Epoch: {}, Running Begin. Check running metrics at heter_log".
            format(epoch))
        reader_type = self.config.get("runner.reader_type")
        if reader_type == "QueueDataset":
            self.exe.infer_from_dataset(
                program=paddle.static.default_main_program(),
                dataset=self.reader,
                debug=config.get("runner.dataset_debug"))
        elif reader_type == "DataLoader":
            batch_id = 0
            train_run_cost = 0.0
            total_examples = 0
            print_step = int(config.get("runner.print_period"))
            self.reader.start()
            while True:
                try:
                    train_start = time.time()
                    self.exe.run(program=paddle.static.default_main_program())
                    train_run_cost += time.time() - train_start
                    total_examples += self.config.get("runner.batch_size")
                    batch_id += 1
                    if batch_id % print_step == 0:
                        profiler_string = ""
                        profiler_string += "avg_batch_cost: {} sec, ".format(
                            format(train_run_cost / print_step, '.5f'))
                        profiler_string += "avg_samples: {}, ".format(
                            format(total_examples / print_step, '.5f'))
                        profiler_string += "ips: {} {}/sec ".format(
                            format(total_examples / train_run_cost, '.5f'),
                            self.count_method)
                        logger.info("Epoch: {}, Batch: {}, {}".format(
                            epoch, batch_id, profiler_string))
                        train_run_cost = 0.0
                        total_examples = 0
                except paddle.core.EOFException:
                    self.reader.reset()
                    break

    def record_result(self):
        logger.info("train_result_dict: {}".format(self.train_result_dict))
        with open("./train_result_dict.txt", 'w+') as f:
            f.write(str(self.train_result_dict))


if __name__ == "__main__":
    paddle.enable_static()
    config = parse_args()
    os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
    benchmark_main = Main(config)
    benchmark_main.run()