|
@@ -51,20 +51,20 @@ logger = logging.getLogger(__name__)
|
|
|
import json
|
|
|
|
|
|
class InferenceFetchHandler(FetchHandler):
|
|
|
- def __init__(self, var_dict, batch_size=1000):
|
|
|
+ def __init__(self, var_dict, output_file, batch_size=1000):
|
|
|
super().__init__(var_dict=var_dict, period_secs=1)
|
|
|
- # self.output_file = output_file
|
|
|
- # self.batch_size = batch_size
|
|
|
- # self.current_batch = []
|
|
|
- # self.total_samples = 0
|
|
|
+ self.output_file = output_file
|
|
|
+ self.batch_size = batch_size
|
|
|
+ self.current_batch = []
|
|
|
+ self.total_samples = 0
|
|
|
|
|
|
- # # 创建输出目录(如果不存在)
|
|
|
- # output_dir = os.path.dirname(output_file)
|
|
|
- # if not os.path.exists(output_dir):
|
|
|
- # os.makedirs(output_dir)
|
|
|
- # # 创建或清空输出文件
|
|
|
- # with open(self.output_file, 'w') as f:
|
|
|
- # f.write('')
|
|
|
+ # 创建输出目录(如果不存在)
|
|
|
+ output_dir = os.path.dirname(output_file)
|
|
|
+ if not os.path.exists(output_dir):
|
|
|
+ os.makedirs(output_dir)
|
|
|
+ # 创建或清空输出文件
|
|
|
+ with open(self.output_file, 'w') as f:
|
|
|
+ f.write('')
|
|
|
|
|
|
def handler(self, fetch_vars):
|
|
|
super().handler(res_dict=fetch_vars)
|