|  | @@ -53,39 +53,39 @@ import json
 | 
	
		
			
				|  |  |  class InferenceFetchHandler(FetchHandler):
 | 
	
		
			
				|  |  |      def __init__(self, var_dict, output_file, batch_size=1000):
 | 
	
		
			
				|  |  |          super().__init__(var_dict=var_dict, period_secs=1)
 | 
	
		
			
				|  |  | -        self.output_file = output_file
 | 
	
		
			
				|  |  | -        self.batch_size = batch_size
 | 
	
		
			
				|  |  | -        self.current_batch = []
 | 
	
		
			
				|  |  | -        self.total_samples = 0
 | 
	
		
			
				|  |  | +        # self.output_file = output_file
 | 
	
		
			
				|  |  | +        # self.batch_size = batch_size
 | 
	
		
			
				|  |  | +        # self.current_batch = []
 | 
	
		
			
				|  |  | +        # self.total_samples = 0
 | 
	
		
			
				|  |  |          
 | 
	
		
			
				|  |  | -        # 创建输出目录(如果不存在)
 | 
	
		
			
				|  |  | -        output_dir = os.path.dirname(output_file)
 | 
	
		
			
				|  |  | -        if not os.path.exists(output_dir):
 | 
	
		
			
				|  |  | -            os.makedirs(output_dir)
 | 
	
		
			
				|  |  | -        # 创建或清空输出文件
 | 
	
		
			
				|  |  | -        with open(self.output_file, 'w') as f:
 | 
	
		
			
				|  |  | -            f.write('')
 | 
	
		
			
				|  |  | +        # # 创建输出目录(如果不存在)
 | 
	
		
			
				|  |  | +        # output_dir = os.path.dirname(output_file)
 | 
	
		
			
				|  |  | +        # if not os.path.exists(output_dir):
 | 
	
		
			
				|  |  | +        #     os.makedirs(output_dir)
 | 
	
		
			
				|  |  | +        # # 创建或清空输出文件
 | 
	
		
			
				|  |  | +        # with open(self.output_file, 'w') as f:
 | 
	
		
			
				|  |  | +        #     f.write('')
 | 
	
		
			
				|  |  |      
 | 
	
		
			
				|  |  |      def handler(self, fetch_vars):
 | 
	
		
			
				|  |  |          super().handler(res_dict=fetch_vars)
 | 
	
		
			
				|  |  |          """处理每批次的推理结果"""
 | 
	
		
			
				|  |  | -        result_dict = {}
 | 
	
		
			
				|  |  | -        logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
 | 
	
		
			
				|  |  | -        for var_name, var_value in fetch_vars.items():
 | 
	
		
			
				|  |  | -            # 转换数据类型
 | 
	
		
			
				|  |  | -            if isinstance(var_value, np.ndarray):
 | 
	
		
			
				|  |  | -                result = var_value.tolist()
 | 
	
		
			
				|  |  | -            else:
 | 
	
		
			
				|  |  | -                result = var_value
 | 
	
		
			
				|  |  | -            result_dict[var_name] = result
 | 
	
		
			
				|  |  | +        # result_dict = {}
 | 
	
		
			
				|  |  | +        # logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars))
 | 
	
		
			
				|  |  | +        # for var_name, var_value in fetch_vars.items():
 | 
	
		
			
				|  |  | +        #     # 转换数据类型
 | 
	
		
			
				|  |  | +        #     if isinstance(var_value, np.ndarray):
 | 
	
		
			
				|  |  | +        #         result = var_value.tolist()
 | 
	
		
			
				|  |  | +        #     else:
 | 
	
		
			
				|  |  | +        #         result = var_value
 | 
	
		
			
				|  |  | +        #     result_dict[var_name] = result
 | 
	
		
			
				|  |  |          
 | 
	
		
			
				|  |  | -        self.current_batch.append(result_dict)
 | 
	
		
			
				|  |  | -        self.total_samples += len(result_dict.get(list(result_dict.keys())[0], []))
 | 
	
		
			
				|  |  | +        # self.current_batch.append(result_dict)
 | 
	
		
			
				|  |  | +        # self.total_samples += len(result_dict.get(list(result_dict.keys())[0], []))
 | 
	
		
			
				|  |  |          
 | 
	
		
			
				|  |  | -        # 当累积足够的结果时,写入文件
 | 
	
		
			
				|  |  | -        if len(self.current_batch) >= self.batch_size:
 | 
	
		
			
				|  |  | -            self._write_batch()
 | 
	
		
			
				|  |  | -            logger.info(f"Saved {self.total_samples} samples to {self.output_file}")
 | 
	
		
			
				|  |  | +        # # 当累积足够的结果时,写入文件
 | 
	
		
			
				|  |  | +        # if len(self.current_batch) >= self.batch_size:
 | 
	
		
			
				|  |  | +        #     self._write_batch()
 | 
	
		
			
				|  |  | +        #     logger.info(f"Saved {self.total_samples} samples to {self.output_file}")
 | 
	
		
			
				|  |  |      
 | 
	
		
			
				|  |  |      def _write_batch(self):
 | 
	
		
			
				|  |  |          """将批次结果写入文件"""
 | 
	
	
		
			
				|  | @@ -302,7 +302,7 @@ class Main(object):
 | 
	
		
			
				|  |  |          output_file = os.path.join(output_dir, f"epoch_{epoch}_results.jsonl")
 | 
	
		
			
				|  |  |          
 | 
	
		
			
				|  |  |          # 创建处理器实例
 | 
	
		
			
				|  |  | -        fetch_handler = InferenceFetchHandler(var_dict = self.metrics, output_file =output_file)
 | 
	
		
			
				|  |  | +        fetch_handler = InferenceFetchHandler(var_dict = self.metrics)
 | 
	
		
			
				|  |  |          # fetch_handler.set_var_dict(self.metrics)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          print(paddle.static.default_main_program()._fleet_opt)
 |