|
@@ -67,18 +67,21 @@ class InferenceFetchHandler(FetchHandler):
|
|
|
f.write('')
|
|
|
|
|
|
def handler(self, fetch_vars):
|
|
|
- super().handler(res_dict=fetch_vars)
|
|
|
"""处理每批次的推理结果"""
|
|
|
result_dict = {}
|
|
|
- # logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
|
|
|
- for var_name, var_value in fetch_vars.items():
|
|
|
+ sys.stdout.write("\n")
|
|
|
+ for key in fetch_vars:
|
|
|
# 转换数据类型
|
|
|
- if isinstance(var_value, np.ndarray):
|
|
|
- result = var_value.tolist()
|
|
|
+ if type(fetch_vars[key]) is np.ndarray:
|
|
|
+ result = res_dict[key][0]
|
|
|
else:
|
|
|
result = var_value
|
|
|
- result_dict[var_name] = result
|
|
|
-
|
|
|
+ result_dict[key] = result
|
|
|
+
|
|
|
+ for key in fetch_vars:
|
|
|
+ if type(fetch_vars[key]) is np.ndarray:
|
|
|
+ sys.stdout.write(f"{key}[0]: {fetch_vars[key][0]} ")
|
|
|
+
|
|
|
self.current_batch.append(result_dict)
|
|
|
|
|
|
# # 当累积足够的结果时,写入文件
|
|
@@ -301,7 +304,6 @@ class Main(object):
|
|
|
|
|
|
# 创建处理器实例
|
|
|
fetch_handler = InferenceFetchHandler(var_dict = self.metrics, output_file = output_file)
|
|
|
- # fetch_handler.set_var_dict(self.metrics)
|
|
|
|
|
|
print(paddle.static.default_main_program()._fleet_opt)
|
|
|
self.exe.infer_from_dataset(
|