to_static.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. os.environ['FLAGS_enable_pir_api'] = '0'
  16. import paddle
  17. import paddle.nn as nn
  18. import time
  19. import logging
  20. import sys
  21. import importlib
  22. __dir__ = os.path.dirname(os.path.abspath(__file__))
  23. #sys.path.append(__dir__)
  24. sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
  25. from utils.utils_single import load_yaml, load_dy_model_class, get_abs_model, create_data_loader
  26. from utils.save_load import load_model, save_model, save_jit_model
  27. from paddle.io import DistributedBatchSampler, DataLoader
  28. import argparse
  29. logging.basicConfig(
  30. format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
  31. logger = logging.getLogger(__name__)
  32. def parse_args():
  33. parser = argparse.ArgumentParser(description='paddle-rec run')
  34. parser.add_argument("-m", "--config_yaml", type=str)
  35. parser.add_argument("-o", "--opt", nargs='*', type=str)
  36. args = parser.parse_args()
  37. args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
  38. args.config_yaml = get_abs_model(args.config_yaml)
  39. return args
  40. def main(args):
  41. paddle.seed(12345)
  42. # load config
  43. config = load_yaml(args.config_yaml)
  44. dy_model_class = load_dy_model_class(args.abs_dir)
  45. config["config_abs_dir"] = args.abs_dir
  46. # modify config from command
  47. if args.opt:
  48. for parameter in args.opt:
  49. parameter = parameter.strip()
  50. key, value = parameter.split("=")
  51. if type(config.get(key)) is int:
  52. value = int(value)
  53. if type(config.get(key)) is bool:
  54. value = (True if value.lower() == "true" else False)
  55. config[key] = value
  56. # tools.vars
  57. use_gpu = config.get("runner.use_gpu", True)
  58. train_data_dir = config.get("runner.train_data_dir", None)
  59. epochs = config.get("runner.epochs", None)
  60. print_interval = config.get("runner.print_interval", None)
  61. model_save_path = config.get("runner.model_save_path", "model_output")
  62. model_init_path = config.get("runner.model_init_path", None)
  63. end_epoch = config.get("runner.infer_end_epoch", 0)
  64. CE = config.get("runner.CE", False)
  65. logger.info("**************common.configs**********")
  66. logger.info(
  67. "use_gpu: {}, train_data_dir: {}, epochs: {}, print_interval: {}, model_save_path: {}".
  68. format(use_gpu, train_data_dir, epochs, print_interval,
  69. model_save_path))
  70. logger.info("**************common.configs**********")
  71. place = paddle.set_device('gpu' if use_gpu else 'cpu')
  72. dy_model = dy_model_class.create_model(config)
  73. if not CE:
  74. model_save_path = os.path.join(model_save_path, str(end_epoch - 1))
  75. load_model(model_init_path, dy_model)
  76. # example dnn model forward
  77. dy_model = paddle.jit.to_static(
  78. dy_model,
  79. input_spec=[[
  80. paddle.static.InputSpec(
  81. shape=[None, 1], dtype='int64') for jj in range(26)
  82. ], paddle.static.InputSpec(
  83. shape=[None, 13], dtype='float32')])
  84. save_jit_model(dy_model, model_save_path, prefix='tostatic')
  85. if __name__ == '__main__':
  86. args = parse_args()
  87. main(args)