|
@@ -71,21 +71,19 @@ class InferenceFetchHandler(FetchHandler):
|
|
|
"""处理每批次的推理结果"""
|
|
|
result_dict = {}
|
|
|
# logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
|
|
|
- # for var_name, var_value in fetch_vars.items():
|
|
|
- # # 转换数据类型
|
|
|
- # if isinstance(var_value, np.ndarray):
|
|
|
- # result = var_value.tolist()
|
|
|
- # else:
|
|
|
- # result = var_value
|
|
|
- # result_dict[var_name] = result
|
|
|
+ for var_name, var_value in fetch_vars.items():
|
|
|
+ # 转换数据类型
|
|
|
+ if isinstance(var_value, np.ndarray):
|
|
|
+ result = var_value.tolist()
|
|
|
+ else:
|
|
|
+ result = var_value
|
|
|
+ result_dict[var_name] = result
|
|
|
|
|
|
- # self.current_batch.append(result_dict)
|
|
|
- # self.total_samples += len(result_dict.get(list(result_dict.keys())[0], []))
|
|
|
+ self.current_batch.append(result_dict)
|
|
|
|
|
|
# # 当累积足够的结果时,写入文件
|
|
|
- # if len(self.current_batch) >= self.batch_size:
|
|
|
- # self._write_batch()
|
|
|
- # logger.info(f"Saved {self.total_samples} samples to {self.output_file}")
|
|
|
+ if len(self.current_batch) >= self.batch_size:
|
|
|
+ self._write_batch()
|
|
|
|
|
|
def _write_batch(self):
|
|
|
"""将批次结果写入文件"""
|
|
@@ -314,7 +312,7 @@ class Main(object):
|
|
|
print_period=print_step,
|
|
|
debug=debug,
|
|
|
fetch_handler=fetch_handler)
|
|
|
-
|
|
|
+ fetch_handler.finish()
|
|
|
|
|
|
|
|
|
def heter_train_loop(self, epoch):
|