| 
					
				 | 
			
			
				@@ -69,23 +69,23 @@ class InferenceFetchHandler(FetchHandler): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def handler(self, fetch_vars): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         super().handler(res_dict=fetch_vars) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """处理每批次的推理结果""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # result_dict = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # for var_name, var_value in fetch_vars.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #     # 转换数据类型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #     if isinstance(var_value, np.ndarray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #         result = var_value.tolist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #     else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #         result = var_value 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #     result_dict[var_name] = result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        result_dict = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for var_name, var_value in fetch_vars.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            # 转换数据类型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if isinstance(var_value, np.ndarray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                result = var_value.tolist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                result = var_value 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            result_dict[var_name] = result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				          
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # self.current_batch.append(result_dict) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # self.total_samples += len(result_dict.get(list(result_dict.keys())[0], [])) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.current_batch.append(result_dict) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        self.total_samples += len(result_dict.get(list(result_dict.keys())[0], [])) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				          
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # # 当累积足够的结果时,写入文件 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # if len(self.current_batch) >= self.batch_size: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #     self._write_batch() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        #     logger.info(f"Saved {self.total_samples} samples to {self.output_file}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # 当累积足够的结果时,写入文件 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if len(self.current_batch) >= self.batch_size: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            self._write_batch() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            logger.info(f"Saved {self.total_samples} samples to {self.output_file}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				      
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _write_batch(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """将批次结果写入文件""" 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -302,7 +302,7 @@ class Main(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         output_file = os.path.join(output_dir, f"epoch_{epoch}_results.jsonl") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				          
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # 创建处理器实例 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        fetch_handler = InferenceFetchHandler(var_dict = self.metrics) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fetch_handler = InferenceFetchHandler(var_dict = self.metrics, output_file = output_file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # fetch_handler.set_var_dict(self.metrics) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         print(paddle.static.default_main_program()._fleet_opt) 
			 |