|
@@ -50,23 +50,18 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
import json
|
|
|
|
|
|
-class InferenceFetchHandler(object):
|
|
|
- def __init__(self, output_file, batch_size=1000):
|
|
|
+class InferenceFetchHandler(FetchHandler):
|
|
|
+ def __init__(self, var_dict, output_file, batch_size=1000):
|
|
|
+ super().__init__(var_dict=var_dict, period_secs=1)
|
|
|
self.output_file = output_file
|
|
|
self.batch_size = batch_size
|
|
|
self.current_batch = []
|
|
|
self.total_samples = 0
|
|
|
|
|
|
- # 添加 Paddle 需要的属性
|
|
|
- self.period_secs = 1 # 设置默认的周期时间(秒)
|
|
|
- self.done_event = threading.Event() # 添加完成事件
|
|
|
- self.terminal_event = threading.Event() # 添加终止事件
|
|
|
-
|
|
|
# 创建输出目录(如果不存在)
|
|
|
output_dir = os.path.dirname(output_file)
|
|
|
if not os.path.exists(output_dir):
|
|
|
os.makedirs(output_dir)
|
|
|
- self.var_dict = {} # 用于存储需要获取的变量
|
|
|
# 创建或清空输出文件
|
|
|
with open(self.output_file, 'w') as f:
|
|
|
f.write('')
|
|
@@ -106,16 +101,6 @@ class InferenceFetchHandler(object):
|
|
|
logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}")
|
|
|
self.done_event.set()
|
|
|
|
|
|
- def stop(self):
|
|
|
- """停止处理"""
|
|
|
- self.terminal_event.set()
|
|
|
- self.finish()
|
|
|
- def set_var_dict(self, var_dict):
|
|
|
- """设置需要获取的变量字典"""
|
|
|
- self.var_dict = var_dict
|
|
|
- def wait(self):
|
|
|
- """等待处理完成"""
|
|
|
- self.done_event.wait()
|
|
|
|
|
|
|
|
|
def parse_args():
|
|
@@ -318,7 +303,7 @@ class Main(object):
|
|
|
# 创建处理器实例
|
|
|
fetch_handler = InferenceFetchHandler(output_file)
|
|
|
fetch_handler.set_var_dict(self.metrics)
|
|
|
-
|
|
|
+
|
|
|
print(paddle.static.default_main_program()._fleet_opt)
|
|
|
self.exe.infer_from_dataset(
|
|
|
program=paddle.static.default_main_program(),
|