# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert a dynamic-graph model to a static graph and save it as a JIT model.

The script loads a YAML config, applies optional ``key=value`` overrides from
the command line, builds the model, loads its parameters, converts it with
``paddle.jit.to_static`` and saves the result via ``save_jit_model``.
"""

import os

# Must be set before paddle is imported so the flag is seen at import time.
os.environ['FLAGS_enable_pir_api'] = '0'

import paddle
import paddle.nn as nn
import time
import logging
import sys
import importlib

__dir__ = os.path.dirname(os.path.abspath(__file__))
# Make the parent directory importable so that the `utils` package resolves.
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from utils.utils_single import load_yaml, load_dy_model_class, get_abs_model, create_data_loader
from utils.save_load import load_model, save_model, save_jit_model
from paddle.io import DistributedBatchSampler, DataLoader
import argparse

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)


def parse_args():
    """Parse the command line.

    Returns:
        argparse.Namespace with ``config_yaml`` (passed through
        ``get_abs_model``), ``opt`` (list of ``key=value`` strings or None)
        and ``abs_dir`` (directory containing the config file).
    """
    parser = argparse.ArgumentParser(description='paddle-rec run')
    # The config path is dereferenced unconditionally below, so let argparse
    # reject a missing -m with a usage message instead of a TypeError.
    parser.add_argument("-m", "--config_yaml", type=str, required=True)
    parser.add_argument("-o", "--opt", nargs='*', type=str)
    args = parser.parse_args()
    # abs_dir must be derived from the user-supplied path before config_yaml
    # is rewritten by get_abs_model.
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    args.config_yaml = get_abs_model(args.config_yaml)
    return args


def _coerce_override(current, raw):
    """Convert the string ``raw`` to the type of the value it replaces.

    Only ``bool`` and ``int`` config values are converted; anything else
    (including keys not present in the config) keeps the raw string.
    """
    # bool is tested first because bool is a subclass of int.
    if isinstance(current, bool):
        return raw.lower() == "true"
    if isinstance(current, int):
        return int(raw)
    return raw


def _apply_overrides(config, overrides):
    """Apply ``key=value`` command-line overrides to ``config`` in place.

    Raises:
        ValueError: if an entry contains no ``=`` or an int-typed key is
            given a non-integer value.
    """
    for item in overrides:
        # Split on the first '=' only so values may themselves contain '='.
        key, sep, raw = item.strip().partition("=")
        if not sep:
            raise ValueError(
                "invalid -o/--opt entry {!r}; expected key=value".format(item))
        config[key] = _coerce_override(config.get(key), raw)


def main(args):
    """Build the model described by ``args`` and export it as a static graph.

    Args:
        args: namespace produced by ``parse_args`` (uses ``config_yaml``,
            ``abs_dir`` and ``opt``).
    """
    paddle.seed(12345)

    # load config
    config = load_yaml(args.config_yaml)
    dy_model_class = load_dy_model_class(args.abs_dir)
    config["config_abs_dir"] = args.abs_dir

    # modify config from command
    if args.opt:
        _apply_overrides(config, args.opt)

    # tools.vars
    use_gpu = config.get("runner.use_gpu", True)
    train_data_dir = config.get("runner.train_data_dir", None)
    epochs = config.get("runner.epochs", None)
    print_interval = config.get("runner.print_interval", None)
    model_save_path = config.get("runner.model_save_path", "model_output")
    model_init_path = config.get("runner.model_init_path", None)
    end_epoch = config.get("runner.infer_end_epoch", 0)
    CE = config.get("runner.CE", False)

    logger.info("**************common.configs**********")
    logger.info(
        "use_gpu: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
        format(use_gpu, train_data_dir, epochs, print_interval,
               model_save_path))
    logger.info("**************common.configs**********")

    # Called for its side effect of selecting the device; the return value
    # is not needed.
    paddle.set_device('gpu' if use_gpu else 'cpu')

    dy_model = dy_model_class.create_model(config)

    if not CE:
        # NOTE(review): with the default infer_end_epoch of 0 this resolves
        # to a "-1" subdirectory -- confirm that is intended.
        model_save_path = os.path.join(model_save_path, str(end_epoch - 1))

    load_model(model_init_path, dy_model)

    # example dnn model forward
    # NOTE(review): the input signature is hard-coded (26 int64 inputs of
    # shape [None, 1] plus one float32 input of shape [None, 13]);
    # presumably it must be adapted for models with different inputs.
    sparse_specs = [
        paddle.static.InputSpec(
            shape=[None, 1], dtype='int64') for _ in range(26)
    ]
    dense_spec = paddle.static.InputSpec(shape=[None, 13], dtype='float32')
    dy_model = paddle.jit.to_static(
        dy_model, input_spec=[sparse_specs, dense_spec])
    save_jit_model(dy_model, model_save_path, prefix='tostatic')


if __name__ == '__main__':
    args = parse_args()
    main(args)