|
@@ -48,7 +48,6 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
import json
|
|
|
|
|
|
-# 创建推理结果处理器类
|
|
|
class InferenceFetchHandler(object):
|
|
|
def __init__(self, output_file, batch_size=1000):
|
|
|
self.output_file = output_file
|
|
@@ -56,6 +55,11 @@ class InferenceFetchHandler(object):
|
|
|
self.current_batch = []
|
|
|
self.total_samples = 0
|
|
|
|
|
|
+ # 添加 Paddle 需要的属性
|
|
|
+ self.period_secs = 60 # 设置默认的周期时间(秒)
|
|
|
+ self.done_event = threading.Event() # 添加完成事件
|
|
|
+ self.terminal_event = threading.Event() # 添加终止事件
|
|
|
+
|
|
|
# 创建输出目录(如果不存在)
|
|
|
output_dir = os.path.dirname(output_file)
|
|
|
if not os.path.exists(output_dir):
|
|
@@ -96,6 +100,16 @@ class InferenceFetchHandler(object):
|
|
|
if self.current_batch:
|
|
|
self._write_batch()
|
|
|
logger.info(f"Final save: total {self.total_samples} samples saved to {self.output_file}")
|
|
|
+ self.done_event.set()
|
|
|
+
|
|
|
+ def stop(self):
|
|
|
+ """停止处理"""
|
|
|
+ self.terminal_event.set()
|
|
|
+ self.finish()
|
|
|
+
|
|
|
+ def wait(self):
|
|
|
+ """等待处理完成"""
|
|
|
+ self.done_event.wait()
|
|
|
|
|
|
|
|
|
def parse_args():
|