| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at##     http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.from __future__ import print_functionimport osos.environ['FLAGS_enable_pir_api'] = '0'from utils.static_ps.reader_helper import get_reader, get_infer_reader, get_example_num, get_file_list, get_word_numfrom utils.static_ps.program_helper import get_model, get_strategy, set_dump_configfrom utils.static_ps.common_ps import YamlHelper, is_distributed_envimport argparseimport timeimport sysimport paddle.distributed.fleet as fleetimport paddle.distributed.fleet.base.role_maker as role_makerfrom paddle.distributed.ps.coordinator import FLClientimport paddleimport warningsimport loggingimport astimport numpy as npimport struct__dir__ = os.path.dirname(os.path.abspath(__file__))sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))logging.basicConfig(    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)logger = logging.getLogger(__name__)def parse_args():    parser = argparse.ArgumentParser("PaddleRec train script")    parser.add_argument(        '-m',        '--config_yaml',        type=str,        required=True,        help='config file path')    parser.add_argument(        '-bf16',        '--pure_bf16',        type=ast.literal_eval,        default=False,        help="whether use bf16")    args = parser.parse_args()    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))    yaml_helper = YamlHelper()    config = yaml_helper.load_yaml(args.config_yaml)    config["yaml_path"] = args.config_yaml    config["config_abs_dir"] = args.abs_dir    config["pure_bf16"] = args.pure_bf16    yaml_helper.print_yaml(config)    return configdef bf16_to_fp32(val):    return np.float32(struct.unpack('<f', struct.pack('<I', val << 16))[0])class MyFLClient(FLClient):    def __init__(self):        passclass Trainer(object):    def __init__(self, config):        self.metrics = {}        self.config = config        self.input_data = None        self.train_dataset = None        self.test_dataset = None        self.model = None        self.pure_bf16 = self.config['pure_bf16']        self.use_cuda = int(self.config.get("runner.use_gpu"))        self.place = paddle.CUDAPlace(0) if self.use_cuda else paddle.CPUPlace(        )        self.role = None    def run(self):        self.init_fleet()        self.init_network()        if fleet.is_server():            self.run_server()        elif fleet.is_worker():            self.init_reader()            self.run_worker()        elif fleet.is_coordinator():            self.run_coordinator()        logger.info("Run Success, Exit.")    def init_fleet(self, use_gloo=True):        if use_gloo:            os.environ["PADDLE_WITH_GLOO"] = "1"            self.role = role_maker.PaddleCloudRoleMaker()            fleet.init(self.role)        else:            fleet.init()    def init_network(self):        self.model = get_model(self.config)        self.input_data = self.model.create_feeds()        self.metrics = self.model.net(self.input_data)        self.model.create_optimizer(get_strategy(self.config))  ## get_strategy        if self.pure_bf16:            self.model.optimizer.amp_init(self.place)    def init_reader(self):        self.train_dataset, self.train_file_list = get_reader(self.input_data,                                                              config)        self.test_dataset, self.test_file_list = get_infer_reader(            self.input_data, config)        if self.role is not None:            self.fl_client = MyFLClient()            self.fl_client.set_basic_config(self.role, self.config,                                            self.metrics)        else:            raise ValueError("self.role is none")        self.fl_client.set_train_dataset_info(self.train_dataset,                                              self.train_file_list)        self.fl_client.set_test_dataset_info(self.test_dataset,                                             self.test_file_list)        example_nums = 0        self.count_method = self.config.get("runner.example_count_method",                                            "example")        if self.count_method == "example":            example_nums = get_example_num(self.train_file_list)        elif self.count_method == "word":            example_nums = get_word_num(self.train_file_list)        else:            raise ValueError(                "Set static_benchmark.example_count_method for example / word for example count."            )        self.fl_client.set_train_example_num(example_nums)    def run_coordinator(self):        logger.info("Run Coordinator Begin")        fleet.init_coordinator()        fleet.make_fl_strategy()    def run_server(self):        logger.info("Run Server Begin")        fleet.init_server(config.get("runner.warmup_model_path"))        fleet.run_server()    def run_worker(self):        logger.info("Run Worker Begin")        self.fl_client.run()if __name__ == "__main__":    paddle.enable_static()    config = parse_args()    os.environ["CPU_NUM"] = str(config.get("runner.thread_num"))    trainer = Trainer(config)    trainer.run()
 |