|
@@ -73,6 +73,7 @@ class InferenceFetchHandler(object):
|
|
|
def handler(self, fetch_vars):
|
|
|
"""处理每批次的推理结果"""
|
|
|
result_dict = {}
|
|
|
+ print("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
|
|
|
for var_name, var_value in fetch_vars.items():
|
|
|
# 转换数据类型
|
|
|
if isinstance(var_value, np.ndarray):
|
|
@@ -315,7 +316,7 @@ class Main(object):
|
|
|
# 创建处理器实例
|
|
|
fetch_handler = InferenceFetchHandler(output_file)
|
|
|
print(paddle.static.default_main_program()._fleet_opt)
|
|
|
- results = self.exe.infer_from_dataset(
|
|
|
+ self.exe.infer_from_dataset(
|
|
|
program=paddle.static.default_main_program(),
|
|
|
dataset=self.reader,
|
|
|
fetch_list=fetch_vars,
|
|
@@ -323,7 +324,6 @@ class Main(object):
|
|
|
print_period=print_step,
|
|
|
debug=debug,
|
|
|
fetch_handler=fetch_handler)
|
|
|
- print("results {}".format(results))
|
|
|
fetch_handler.finish()
|
|
|
|
|
|
|