static_infer.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import os
os.environ['FLAGS_enable_pir_api'] = '0'
import warnings
import logging
import paddle
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
#sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from utils.utils_single import load_yaml, load_static_model_class, get_abs_model, create_data_loader, reset_auc
from utils.save_load import save_static_model, load_static_model, save_data
import time
import argparse

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

def parse_args():
    parser = argparse.ArgumentParser("PaddleRec static infer script")
    parser.add_argument("-m", "--config_yaml", type=str)
    parser.add_argument("-o", "--opt", nargs='*', type=str)
    args = parser.parse_args()
    args.abs_dir = os.path.dirname(os.path.abspath(args.config_yaml))
    args.config_yaml = get_abs_model(args.config_yaml)
    return args
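
# Illustrative invocation (not part of the original file; the config path is a
# placeholder for your own PaddleRec-style config.yaml, and "-o" overrides are
# optional):
#
#   python -u static_infer.py -m path/to/config.yaml \
#       -o runner.infer_load_path=output_model runner.infer_start_epoch=0
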
def main(args):
    paddle.seed(12345)
    # load config
    config = load_yaml(args.config_yaml)
    config["config_abs_dir"] = args.abs_dir
    # override config values from the command line ("-o key=value"),
    # casting each value to the type of the existing config entry
    if args.opt:
        for parameter in args.opt:
            parameter = parameter.strip()
            key, value = parameter.split("=")
            if type(config.get(key)) is int:
                value = int(value)
            if type(config.get(key)) is float:
                value = float(value)
            if type(config.get(key)) is bool:
                value = (value.lower() == "true")
            config[key] = value
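    # Example of the override typing above (illustrative, not from the original
    # script): with
    #   -o runner.use_gpu=True runner.infer_batch_size=512
    # "runner.use_gpu" is cast to bool and "runner.infer_batch_size" to int only
    # if those keys already exist in the YAML with those types; otherwise the
    # override is stored as a plain string.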
    # load static model class
    static_model_class = load_static_model_class(config)

    input_data = static_model_class.create_feeds(is_infer=True)
    input_data_names = [data.name for data in input_data]

    fetch_vars = static_model_class.infer_net(input_data)
    logger.info("cpu_num: {}".format(os.getenv("CPU_NUM")))

    use_gpu = config.get("runner.use_gpu", True)
    use_xpu = config.get("runner.use_xpu", False)
    use_auc = config.get("runner.use_auc", False)
    use_visual = config.get("runner.use_visual", False)
    auc_num = config.get("runner.auc_num", 1)
    test_data_dir = config.get("runner.test_data_dir", None)
    print_interval = config.get("runner.print_interval", None)
    model_load_path = config.get("runner.infer_load_path", "model_output")
    start_epoch = config.get("runner.infer_start_epoch", 0)
    end_epoch = config.get("runner.infer_end_epoch", 10)
    batch_size = config.get("runner.infer_batch_size", None)
    use_save_data = config.get("runner.use_save_data", False)
    reader_type = config.get("runner.reader_type", "DataLoader")
    use_fleet = config.get("runner.use_fleet", False)
    os.environ["CPU_NUM"] = str(config.get("runner.thread_num", 1))

    logger.info("**************common.configs**********")
    logger.info(
        "use_gpu: {}, use_xpu: {}, use_visual: {}, infer_batch_size: {}, test_data_dir: {}, start_epoch: {}, end_epoch: {}, print_interval: {}, model_load_path: {}".
        format(use_gpu, use_xpu, use_visual, batch_size, test_data_dir,
               start_epoch, end_epoch, print_interval, model_load_path))
    logger.info("**************common.configs**********")

    if use_xpu:
        xpu_device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
        place = paddle.set_device(xpu_device)
    else:
        place = paddle.set_device('gpu' if use_gpu else 'cpu')

    exe = paddle.static.Executor(place)
    # initialize
    exe.run(paddle.static.default_startup_program())

    if reader_type == 'DataLoader':
        test_dataloader = create_data_loader(
            config=config, place=place, mode="test")
    elif reader_type == "CustomizeDataLoader":
        test_dataloader = static_model_class.create_data_loader()

    # Create a log_visual object and store the data in the path
    if use_visual:
        from visualdl import LogWriter
        log_visual = LogWriter(args.abs_dir + "/visualDL_log/infer")
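    # Note (not in the original script): the records written under
    # <config dir>/visualDL_log/infer can typically be browsed with the
    # VisualDL CLI, e.g. `visualdl --logdir ./visualDL_log/infer`; the exact
    # command is assumed from the visualdl package and may differ per install.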
    step_num = 0

    for epoch_id in range(start_epoch, end_epoch):
        logger.info("load model epoch {}".format(epoch_id))
        model_path = os.path.join(model_load_path, str(epoch_id))
        load_static_model(
            paddle.static.default_main_program(),
            model_path,
            prefix='rec_static')
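        # Assumption for readers: load_static_model is expected to restore the
        # parameters that the matching static training script saved under
        # <model_load_path>/<epoch_id>/ with the 'rec_static' prefix; see
        # utils/save_load.py for the actual loading logic.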

        epoch_begin = time.time()
        interval_begin = time.time()
        infer_reader_cost = 0.0
        infer_run_cost = 0.0
        reader_start = time.time()

        if use_auc:
            reset_auc(use_fleet, auc_num)

        # we will drop the last incomplete batch when dataset size is not divisible by the batch size
        assert any(test_dataloader(
        )), "test_dataloader is empty, please ensure batch size < dataset size!"
        for batch_id, batch_data in enumerate(test_dataloader()):
            infer_reader_cost += time.time() - reader_start
            infer_start = time.time()
            fetch_batch_var = exe.run(
                program=paddle.static.default_main_program(),
                feed=dict(zip(input_data_names, batch_data)),
                fetch_list=[var for _, var in fetch_vars.items()])
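            # exe.run returns one numpy array per entry in fetch_list, so
            # fetch_batch_var[var_idx] below lines up with the iteration
            # order of the fetch_vars dict.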
            infer_run_cost += time.time() - infer_start
            if batch_id % print_interval == 0:
                metric_str = ""
                for var_idx, var_name in enumerate(fetch_vars):
                    metric_str += "{}: {}, ".format(var_name,
                                                    fetch_batch_var[var_idx])
                    if use_visual:
                        log_visual.add_scalar(
                            tag="infer/" + var_name,
                            step=step_num,
                            value=fetch_batch_var[var_idx])
                logger.info(
                    "epoch: {}, batch_id: {}, ".format(epoch_id,
                                                       batch_id) + metric_str +
                    "avg_reader_cost: {:.5f} sec, avg_batch_cost: {:.5f} sec, avg_samples: {:.5f}, ips: {:.2f} ins/s".
                    format(infer_reader_cost / print_interval, (
                        infer_reader_cost + infer_run_cost) / print_interval,
                           batch_size, print_interval * batch_size / (
                               time.time() + 0.0001 - interval_begin)))
                interval_begin = time.time()
                infer_reader_cost = 0.0
                infer_run_cost = 0.0
            reader_start = time.time()
            step_num = step_num + 1

        metric_str = ""
        for var_idx, var_name in enumerate(fetch_vars):
            metric_str += "{}: {}, ".format(var_name, fetch_batch_var[var_idx])
        logger.info("epoch: {} done, ".format(epoch_id) + metric_str +
                    "epoch time: {:.2f} s".format(time.time() - epoch_begin))
        if use_save_data:
            save_data(fetch_batch_var, model_load_path)


if __name__ == "__main__":
    paddle.enable_static()
    args = parse_args()
    main(args)