| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101 | 
							- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 
- #
 
- # Licensed under the Apache License, Version 2.0 (the "License");
 
- # you may not use this file except in compliance with the License.
 
- # You may obtain a copy of the License at
 
- #
 
- #     http://www.apache.org/licenses/LICENSE-2.0
 
- #
 
- # Unless required by applicable law or agreed to in writing, software
 
- # distributed under the License is distributed on an "AS IS" BASIS,
 
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 
- # See the License for the specific language governing permissions and
 
- # limitations under the License.
 
- import os
 
- os.environ['FLAGS_enable_pir_api'] = '0'
 
- import paddle
 
- import paddle.nn as nn
 
- import time
 
- import logging
 
- import sys
 
- import importlib
 
- __dir__ = os.path.dirname(os.path.abspath(__file__))
 
- #sys.path.append(__dir__)
 
- sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
 
- from utils.utils_single import load_yaml, load_dy_model_class, get_abs_model, create_data_loader
 
- from utils.save_load import load_model, save_model, save_jit_model
 
- from paddle.io import DistributedBatchSampler, DataLoader
 
- import argparse
 
# Root logging: INFO level, timestamped "time - LEVEL - message" lines.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Module-level logger used by main() to report the effective config.
logger = logging.getLogger(__name__)
 
def parse_args():
    """Parse command-line options for the to_static export tool.

    Returns:
        argparse.Namespace with:
          * ``config_yaml`` -- config path after resolution by ``get_abs_model``;
          * ``opt`` -- list of ``key=value`` override strings, or None;
          * ``abs_dir`` -- absolute directory of the config path as given
            on the command line.
    """
    parser = argparse.ArgumentParser(description='paddle-rec run')
    # Required: abs_dir below is derived from this path, and
    # os.path.abspath(None) would otherwise fail with an opaque TypeError.
    parser.add_argument(
        "-m",
        "--config_yaml",
        type=str,
        required=True,
        help="path to the model's config yaml")
    parser.add_argument(
        "-o",
        "--opt",
        nargs='*',
        type=str,
        help="config overrides given as key=value")
    args = parser.parse_args()
    # abs_dir comes from the path as passed, before get_abs_model rewrites
    # config_yaml.
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    args.config_yaml = get_abs_model(args.config_yaml)
    return args
 
def _apply_cli_overrides(config, overrides):
    """Apply ``key=value`` command-line overrides to ``config`` in place.

    The value is coerced to the type of the existing entry when that entry
    is a ``bool``, ``int`` or ``float``. Otherwise, including keys not yet
    present in ``config``, the raw string is stored.

    Args:
        config: config dict (flat dotted keys such as ``runner.use_gpu``).
        overrides: iterable of ``"key=value"`` strings, or None.

    Raises:
        ValueError: if an override has no ``=`` or its value cannot be
            converted to the existing entry's numeric type.
    """
    for parameter in overrides or []:
        parameter = parameter.strip()
        if "=" not in parameter:
            raise ValueError(
                "invalid override {!r}: expected key=value".format(parameter))
        # Split on the first '=' only so values may themselves contain '='.
        key, value = parameter.split("=", 1)
        current = config.get(key)
        # Exact type checks (not isinstance) so bool entries are not
        # treated as ints.
        if type(current) is bool:
            value = value.lower() == "true"
        elif type(current) is int:
            value = int(value)
        elif type(current) is float:
            value = float(value)
        config[key] = value


def _build_input_spec(num_sparse_slots=26, dense_dim=13):
    """Return the ``input_spec`` used to trace the model with ``to_static``.

    The spec is ``[sparse_specs, dense_spec]``:
      * ``sparse_specs`` is ``num_sparse_slots`` int64 specs of shape ``[None, 1]``;
      * ``dense_spec`` is one float32 spec of shape ``[None, dense_dim]``.

    The defaults reproduce the layout this script has always exported.
    NOTE(review): this assumes the model's forward takes
    (sparse_inputs, dense_inputs); confirm for models other than the
    example DNN.
    """
    sparse_specs = [
        paddle.static.InputSpec(shape=[None, 1], dtype='int64')
        for _ in range(num_sparse_slots)
    ]
    dense_spec = paddle.static.InputSpec(
        shape=[None, dense_dim], dtype='float32')
    return [sparse_specs, dense_spec]


def main(args):
    """Load a trained dygraph model and export it as a static (jit) model.

    Steps:
      1. Read the YAML config and apply ``-o key=value`` overrides.
      2. Build the model via the class found next to the config.
      3. Load weights from ``runner.model_init_path``.
      4. Trace with ``paddle.jit.to_static`` and save via ``save_jit_model``
         under ``runner.model_save_path`` with prefix ``tostatic``.

    Args:
        args: namespace from ``parse_args`` (uses ``config_yaml``,
            ``abs_dir`` and ``opt``).

    Raises:
        ValueError: if ``runner.model_init_path`` is not set, or a
            command-line override is malformed.
    """
    paddle.seed(12345)

    # load config
    config = load_yaml(args.config_yaml)
    dy_model_class = load_dy_model_class(args.abs_dir)
    config["config_abs_dir"] = args.abs_dir

    # modify config from command (applied after config_abs_dir so it can
    # still be overridden, as before)
    _apply_cli_overrides(config, args.opt)

    # tools.vars
    use_gpu = config.get("runner.use_gpu", True)
    train_data_dir = config.get("runner.train_data_dir", None)
    epochs = config.get("runner.epochs", None)
    print_interval = config.get("runner.print_interval", None)
    model_save_path = config.get("runner.model_save_path", "model_output")
    model_init_path = config.get("runner.model_init_path", None)
    end_epoch = config.get("runner.infer_end_epoch", 0)
    is_ce_run = config.get("runner.CE", False)

    logger.info("**************common.configs**********")
    logger.info(
        "use_gpu: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
        format(use_gpu, train_data_dir, epochs, print_interval,
               model_save_path))
    logger.info("**************common.configs**********")

    # Fail fast with a clear message instead of handing None to load_model.
    if model_init_path is None:
        raise ValueError(
            "runner.model_init_path must point to the trained model to export")

    # Only the device-selection side effect is needed; the returned place
    # object was never used.
    paddle.set_device('gpu' if use_gpu else 'cpu')

    dy_model = dy_model_class.create_model(config)

    # Outside CE runs, export into the sub-directory named
    # str(infer_end_epoch - 1). int() tolerates a value supplied as a
    # string through -o when the key was absent from the YAML.
    if not is_ce_run:
        model_save_path = os.path.join(model_save_path,
                                       str(int(end_epoch) - 1))

    load_model(model_init_path, dy_model)

    # example dnn model forward
    dy_model = paddle.jit.to_static(dy_model, input_spec=_build_input_spec())
    save_jit_model(dy_model, model_save_path, prefix='tostatic')
 
if __name__ == '__main__':
    # Script entry point: parse the CLI flags, then run the export.
    args = parse_args()
    main(args)
 
 
  |