|
@@ -53,39 +53,39 @@ import json
|
|
|
class InferenceFetchHandler(FetchHandler):
|
|
|
def __init__(self, var_dict, output_file, batch_size=1000):
|
|
|
super().__init__(var_dict=var_dict, period_secs=1)
|
|
|
- self.output_file = output_file
|
|
|
- self.batch_size = batch_size
|
|
|
- self.current_batch = []
|
|
|
- self.total_samples = 0
|
|
|
+ # self.output_file = output_file
|
|
|
+ # self.batch_size = batch_size
|
|
|
+ # self.current_batch = []
|
|
|
+ # self.total_samples = 0
|
|
|
|
|
|
- # 创建输出目录(如果不存在)
|
|
|
- output_dir = os.path.dirname(output_file)
|
|
|
- if not os.path.exists(output_dir):
|
|
|
- os.makedirs(output_dir)
|
|
|
- # 创建或清空输出文件
|
|
|
- with open(self.output_file, 'w') as f:
|
|
|
- f.write('')
|
|
|
+ # # 创建输出目录(如果不存在)
|
|
|
+ # output_dir = os.path.dirname(output_file)
|
|
|
+ # if not os.path.exists(output_dir):
|
|
|
+ # os.makedirs(output_dir)
|
|
|
+ # # 创建或清空输出文件
|
|
|
+ # with open(self.output_file, 'w') as f:
|
|
|
+ # f.write('')
|
|
|
|
|
|
def handler(self, fetch_vars):
|
|
|
super().handler(res_dict=fetch_vars)
|
|
|
"""处理每批次的推理结果"""
|
|
|
- result_dict = {}
|
|
|
- logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
|
|
|
- for var_name, var_value in fetch_vars.items():
|
|
|
- # 转换数据类型
|
|
|
- if isinstance(var_value, np.ndarray):
|
|
|
- result = var_value.tolist()
|
|
|
- else:
|
|
|
- result = var_value
|
|
|
- result_dict[var_name] = result
|
|
|
+ # result_dict = {}
|
|
|
+ # logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
|
|
|
+ # for var_name, var_value in fetch_vars.items():
|
|
|
+ # # 转换数据类型
|
|
|
+ # if isinstance(var_value, np.ndarray):
|
|
|
+ # result = var_value.tolist()
|
|
|
+ # else:
|
|
|
+ # result = var_value
|
|
|
+ # result_dict[var_name] = result
|
|
|
|
|
|
- self.current_batch.append(result_dict)
|
|
|
- self.total_samples += len(result_dict.get(list(result_dict.keys())[0], []))
|
|
|
+ # self.current_batch.append(result_dict)
|
|
|
+ # self.total_samples += len(result_dict.get(list(result_dict.keys())[0], []))
|
|
|
|
|
|
- # 当累积足够的结果时,写入文件
|
|
|
- if len(self.current_batch) >= self.batch_size:
|
|
|
- self._write_batch()
|
|
|
- logger.info(f"Saved {self.total_samples} samples to {self.output_file}")
|
|
|
+ # # 当累积足够的结果时,写入文件
|
|
|
+ # if len(self.current_batch) >= self.batch_size:
|
|
|
+ # self._write_batch()
|
|
|
+ # logger.info(f"Saved {self.total_samples} samples to {self.output_file}")
|
|
|
|
|
|
def _write_batch(self):
|
|
|
"""将批次结果写入文件"""
|
|
@@ -302,7 +302,7 @@ class Main(object):
|
|
|
output_file = os.path.join(output_dir, f"epoch_{epoch}_results.jsonl")
|
|
|
|
|
|
# 创建处理器实例
|
|
|
- fetch_handler = InferenceFetchHandler(var_dict = self.metrics, output_file =output_file)
|
|
|
+ fetch_handler = InferenceFetchHandler(var_dict = self.metrics)
|
|
|
# fetch_handler.set_var_dict(self.metrics)
|
|
|
|
|
|
print(paddle.static.default_main_program()._fleet_opt)
|