|
@@ -50,13 +50,8 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
import json
|
|
|
|
|
|
-class InferenceFetchHandler(FetchHandler):
|
|
|
- def __init__(self, var_dict, period_secs, output_file, batch_size=1000):
|
|
|
- assert var_dict is not None
|
|
|
- self.var_dict = var_dict
|
|
|
- self.period_secs = period_secs
|
|
|
-
|
|
|
-
|
|
|
+class InferenceFetchHandler(object):
|
|
|
+ def __init__(self, output_file, batch_size=1000):
|
|
|
self.output_file = output_file
|
|
|
self.batch_size = batch_size
|
|
|
self.current_batch = []
|
|
@@ -322,6 +317,7 @@ class Main(object):
|
|
|
|
|
|
# 创建处理器实例
|
|
|
fetch_handler = InferenceFetchHandler(output_file)
|
|
|
+ fetch_handler.set_var_dict(fetch_vars)
|
|
|
print(paddle.static.default_main_program()._fleet_opt)
|
|
|
self.exe.infer_from_dataset(
|
|
|
program=paddle.static.default_main_program(),
|