| 
					
				 | 
			
			
				@@ -70,22 +70,22 @@ class InferenceFetchHandler(FetchHandler): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         super().handler(res_dict=fetch_vars) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """处理每批次的推理结果""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         result_dict = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        for var_name, var_value in fetch_vars.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            # 转换数据类型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if isinstance(var_value, np.ndarray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                result = var_value.tolist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                result = var_value 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            result_dict[var_name] = result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # for var_name, var_value in fetch_vars.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #     # 转换数据类型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #     if isinstance(var_value, np.ndarray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #         result = var_value.tolist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #     else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #         result = var_value 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #     result_dict[var_name] = result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				          
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.current_batch.append(result_dict) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.total_samples += len(result_dict.get(list(result_dict.keys())[0], [])) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # self.current_batch.append(result_dict) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # self.total_samples += len(result_dict.get(list(result_dict.keys())[0], [])) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				          
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # 当累积足够的结果时,写入文件 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if len(self.current_batch) >= self.batch_size: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            self._write_batch() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            logger.info(f"Saved {self.total_samples} samples to {self.output_file}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # # 当累积足够的结果时,写入文件 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # if len(self.current_batch) >= self.batch_size: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #     self._write_batch() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        #     logger.info(f"Saved {self.total_samples} samples to {self.output_file}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				      
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def _write_batch(self): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """将批次结果写入文件""" 
			 |