@@ -0,0 +1,362 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import os
+os.environ['FLAGS_enable_pir_api'] = '0'
+from utils.static_ps.reader_helper_hdfs import get_infer_reader
+from utils.static_ps.program_helper import get_model, get_strategy, set_dump_config
+from utils.static_ps.metric_helper import set_zero, get_global_auc
+from utils.static_ps.common_ps import YamlHelper, is_distributed_env
+import argparse
+import time
+import sys
+import paddle.distributed.fleet as fleet
+import paddle.distributed.fleet.base.role_maker as role_maker
+import paddle
+from paddle.base.executor import FetchHandler
+import queue
+import threading
+
+import warnings
+import logging
+import ast
+import numpy as np
+import struct
+from utils.utils_single import auc
+from utils.oss_client import HangZhouOSSClient
+import utils.compress as compress
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+root_logger = logging.getLogger()
+for handler in root_logger.handlers[:]:
+    root_logger.removeHandler(handler)
+logging.basicConfig(
+    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+import json
+
+class InferenceFetchHandler(FetchHandler):
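+    """Asynchronously write fetched inference results to a JSON-lines file.
+
+    handler() is invoked periodically by the executor with the fetched
+    variables; results are queued and a daemon writer thread appends them
+    to output_file in batches of batch_size.
+    """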
+    def __init__(self, var_dict, output_file, batch_size=1000):
+        super().__init__(var_dict=var_dict, period_secs=1)
+        self.output_file = output_file
+        self.batch_size = batch_size
+        self.result_queue = queue.Queue()
+        self.writer_thread = threading.Thread(target=self._writer)
+        self.writer_thread.daemon = True  # daemon thread: exits with the main process
+        self.writer_thread.start()
+
+        # Create the output directory if it does not exist.
+        output_dir = os.path.dirname(output_file)
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+        # Create or truncate the output file.
+        with open(self.output_file, 'w') as f:
+            f.write('')
+
+    def handler(self, fetch_vars):
+        """Process the inference results of each batch."""
+        result_dict = {}
+
+        for key in fetch_vars:
+            # Convert numpy arrays to plain Python values so they can be JSON-serialized.
+            if type(fetch_vars[key]) is np.ndarray:
+                result = fetch_vars[key][0].tolist()
+            else:
+                result = fetch_vars[key]
+            result_dict[key] = result
+        self.result_queue.put(result_dict)  # hand the result off to the writer thread
+
+    def _writer(self):
+        batch = []
+        while True:
+            try:
+                # Block for at most one second waiting for a new result.
+                result_dict = self.result_queue.get(timeout=1)
+                logger.info("write vector {} {}".format(json.dumps(result_dict), len(batch)))
+                batch.append(result_dict)
+                if len(batch) >= self.batch_size:
+                    logger.info("write vector")
+                    with open(self.output_file, 'a') as f:
+                        for result in batch:
+                            f.write(json.dumps(result) + '\n')
+                    batch = []
+            except queue.Empty:
+                pass
+
+    def _write_batch(self, batch):
+        with open(self.output_file, 'a') as f:
+            for result in batch:
+                f.write(json.dumps(result) + '\n')
+
+    def flush(self):
+        """Make sure all buffered results have been written to the output file."""
+        # Drain whatever is still waiting in the queue and write the last batch.
+        # (queue.Queue.join() is not usable here because task_done() is never called.)
+        remaining = []
+        while not self.result_queue.empty():
+            remaining.append(self.result_queue.get())
+        if remaining:
+            self._write_batch(remaining)
+
+
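+# Parse command-line arguments, load the YAML config given by -m/--config_yaml,
+# and apply any -o key=value overrides on top of it, casting each value to the
+# type already present in the config.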
+def parse_args():
+    parser = argparse.ArgumentParser("PaddleRec infer script")
+    parser.add_argument("-o", "--opt", nargs='*', type=str)
+    parser.add_argument(
+        '-m',
+        '--config_yaml',
+        type=str,
+        required=True,
+        help='config file path')
+    parser.add_argument(
+        '-bf16',
+        '--pure_bf16',
+        type=ast.literal_eval,
+        default=False,
+        help="whether use bf16")
+    args = parser.parse_args()
+    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
+    yaml_helper = YamlHelper()
+    config = yaml_helper.load_yaml(args.config_yaml)
+    # modify config from command line
+    if args.opt:
+        for parameter in args.opt:
+            parameter = parameter.strip()
+            key, value = parameter.split("=")
+            if type(config.get(key)) is int:
+                value = int(value)
+            if type(config.get(key)) is float:
+                value = float(value)
+            if type(config.get(key)) is bool:
+                value = value.lower() == "true"
+            config[key] = value
+    config["yaml_path"] = args.config_yaml
+    config["config_abs_dir"] = args.abs_dir
+    config["pure_bf16"] = args.pure_bf16
+    yaml_helper.print_yaml(config)
+    return config
+
+
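+# Reinterpret a bfloat16 value (stored in the low 16 bits of an integer) as
+# float32 by shifting it into the high 16 bits of an IEEE-754 single-precision
+# bit pattern.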
+def bf16_to_fp32(val):
+    return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])
+
+
+class Main(object):
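+    """Parameter-server style inference driver.
+
+    Builds the model network, initializes fleet, and then runs either the
+    server loop or the worker inference loop depending on the node's role.
+    """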
+    def __init__(self, config):
+        self.metrics = {}
+        self.config = config
+        self.input_data = None
+        self.reader = None
+        self.exe = None
+        self.train_result_dict = {}
+        self.train_result_dict["speed"] = []
+        self.train_result_dict["auc"] = []
+        self.model = None
+        self.pure_bf16 = self.config['pure_bf16']
+
+    def run(self):
+        self.init_fleet_with_gloo()
+        self.network()
+        if fleet.is_server():
+            self.run_server()
+        elif fleet.is_worker():
+            self.run_worker()
+            fleet.stop_worker()
+            self.record_result()
+        logger.info("Run Success, Exit.")
+
+    def init_fleet_with_gloo(self, use_gloo=True):
+        if use_gloo:
+            os.environ["PADDLE_WITH_GLOO"] = "0"
+            role = role_maker.PaddleCloudRoleMaker(
+                is_collective=False,
+                init_gloo=False
+            )
+            fleet.init(role)
+        else:
+            fleet.init()
+
+    def network(self):
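+        """Build the inference graph: feeds, reader, metrics and the target variable.
+
+        create_optimizer() is still called so that the fleet distributed
+        strategy is applied to the program before inference starts.
+        """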
+        self.model = get_model(self.config)
+        self.input_data = self.model.create_feeds(is_infer=True)
+        self.inference_feed_var = self.input_data
+        self.init_reader()
+        self.metrics = self.model.net(self.inference_feed_var, is_infer=True)
+        self.inference_target_var = self.model.inference_target_var
+        logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))
+        self.model.create_optimizer(get_strategy(self.config), is_infer=True)
+
+    def run_server(self):
+        logger.info("Run Server Begin")
+        fleet.init_server(config.get("runner.warmup_model_path"))
+        fleet.run_server()
+
+    def run_worker(self):
+        logger.info("Run Worker Begin")
+        use_cuda = int(config.get("runner.use_gpu"))
+        use_auc = config.get("runner.infer_use_auc", False)
+        place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()
+        self.exe = paddle.static.Executor(place)
+
+        # Dump the worker's main/startup programs for debugging.
+        with open("./{}_worker_main_program.prototxt".format(
+                fleet.worker_index()), 'w+') as f:
+            f.write(str(paddle.static.default_main_program()))
+        with open("./{}_worker_startup_program.prototxt".format(
+                fleet.worker_index()), 'w+') as f:
+            f.write(str(paddle.static.default_startup_program()))
+
+        self.exe.run(paddle.static.default_startup_program())
+        if self.pure_bf16:
+            self.model.optimizer.amp_init(self.exe.place)
+        fleet.init_worker()
+
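+        # Fetch the packaged model from OSS and unpack it into the local
+        # inference load path before loading it with fleet.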
+        init_model_path = config.get("runner.infer_load_path")
+        model_mode = config.get("runner.model_mode", 0)
+        client = HangZhouOSSClient("art-recommend")
+        oss_object_name = self.config.get("runner.oss_object_name", "dyp/model.tar.gz")
+        client.get_object_to_file(oss_object_name, "model.tar.gz")
+        compress.uncompress_tar("model.tar.gz", init_model_path)
+        assert os.path.exists(init_model_path)
+
+        # if fleet.is_first_worker():
+        #     fleet.load_inference_model(init_model_path, mode=int(model_mode))
+        # fleet.barrier_worker()
+
+
+        reader_type = self.config.get("runner.reader_type", "QueueDataset")
+        epochs = int(self.config.get("runner.epochs"))
+        sync_mode = self.config.get("runner.sync_mode")
+        opt_info = paddle.static.default_main_program()._fleet_opt
+        if use_auc is True:
+            opt_info['stat_var_names'] = [
+                self.model.stat_pos.name, self.model.stat_neg.name
+            ]
+        else:
+            opt_info['stat_var_names'] = []
+
+        if reader_type == "InmemoryDataset":
+            self.reader.load_into_memory()
+
+        fleet.load_inference_model(
+            init_model_path,
+            mode=int(model_mode))
+        epoch_start_time = time.time()
+        epoch = 0
+        if sync_mode == "heter":
+            self.heter_train_loop(epoch)
+        elif reader_type == "QueueDataset":
+            self.dataset_train_loop(epoch)
+        elif reader_type == "InmemoryDataset":
+            self.dataset_train_loop(epoch)
+
+        epoch_time = time.time() - epoch_start_time
+        logger.info(
+            "inference done, using time {} second, count method {}.".format(
+                epoch_time, self.count_method))
+
+        if reader_type == "InmemoryDataset":
+            self.reader.release_memory()
+
+        # Keep the worker process alive after inference finishes.
+        while True:
+            time.sleep(300)
+
+    def init_reader(self):
+        if fleet.is_server():
+            return
+        self.config["runner.reader_type"] = self.config.get(
+            "runner.reader_type", "QueueDataset")
+        self.reader, self.file_list = get_infer_reader(self.input_data, config)
+        self.example_nums = 0
+        self.count_method = self.config.get("runner.example_count_method",
+                                            "example")
+
+    def dataset_train_loop(self, epoch=0):
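+        """Run inference batch by batch from self.reader and log the fetched metrics."""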
+        logger.info("Epoch: {}, Running Dataset Begin.".format(epoch))
+
+        fetch_info = [
+            "Epoch {} Var {}".format(epoch, var_name)
+            for var_name in self.metrics
+        ]
+
+        fetch_vars = [var for _, var in self.metrics.items()]
+        input_data_names = [data.name for data in self.input_data]
+        test_dataloader = self.reader
+
+        for batch_id, batch_data in enumerate(test_dataloader()):
+            fetch_batch_var = self.exe.run(
+                program=paddle.static.default_main_program(),
+                feed=dict(zip(input_data_names, batch_data)),
+                fetch_list=fetch_vars)
+
+            logger.info("fetch_batch_var : {}".format(fetch_batch_var))
+
+
+    def heter_train_loop(self, epoch):
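+        """Inference loop for heterogeneous (heter) parameter-server mode.
+
+        QueueDataset readers go through infer_from_dataset; DataLoader readers
+        are driven batch by batch with periodic throughput logging.
+        """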
+        logger.info(
+            "Epoch: {}, Running Begin. Check running metrics at heter_log".
+            format(epoch))
+        reader_type = self.config.get("runner.reader_type")
+        if reader_type == "QueueDataset":
+            self.exe.infer_from_dataset(
+                program=paddle.static.default_main_program(),
+                dataset=self.reader,
+                debug=config.get("runner.dataset_debug"))
+        elif reader_type == "DataLoader":
+            batch_id = 0
+            train_run_cost = 0.0
+            total_examples = 0
+            self.reader.start()
+            while True:
+                try:
+                    train_start = time.time()
+                    # --------------------------------------------------- #
+                    self.exe.run(program=paddle.static.default_main_program())
+                    # --------------------------------------------------- #
+                    train_run_cost += time.time() - train_start
+                    total_examples += self.config.get("runner.batch_size")
+                    batch_id += 1
+                    print_step = int(config.get("runner.print_period"))
+                    if batch_id % print_step == 0:
+                        profiler_string = ""
+                        profiler_string += "avg_batch_cost: {} sec, ".format(
+                            format((train_run_cost) / print_step, '.5f'))
+                        profiler_string += "avg_samples: {}, ".format(
+                            format(total_examples / print_step, '.5f'))
+                        profiler_string += "ips: {} {}/sec ".format(
+                            format(total_examples / (train_run_cost), '.5f'),
+                            self.count_method)
+                        logger.info("Epoch: {}, Batch: {}, {}".format(
+                            epoch, batch_id, profiler_string))
+                        train_run_cost = 0.0
+                        total_examples = 0
+                except paddle.core.EOFException:
+                    self.reader.reset()
+                    break
+
+    def record_result(self):
+        logger.info("train_result_dict: {}".format(self.train_result_dict))
+        with open("./train_result_dict.txt", 'w+') as f:
+            f.write(str(self.train_result_dict))
+
+
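+# Entry point. Static graph mode is required by the fleet parameter-server APIs.
+# Typical launch (the actual script and config names are not part of this diff), e.g.:
+#   python -u <this_script>.py -m config.yaml -o runner.infer_load_path=./model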
+if __name__ == "__main__":
+    paddle.enable_static()
+
+    config = parse_args()
+    os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))
+    benchmark_main = Main(config)
+
+    benchmark_main.run()