| 
					
				 | 
			
			
				@@ -67,18 +67,21 @@ class InferenceFetchHandler(FetchHandler): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             f.write('') 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				      
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     def handler(self, fetch_vars): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        super().handler(res_dict=fetch_vars) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         """处理每批次的推理结果""" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         result_dict = {} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # logger.info("InferenceFetchHandler fetch_vars {}".format(fetch_vars)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        for var_name, var_value in fetch_vars.items(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        sys.stdout.write("\n") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for key in fetch_vars: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             # 转换数据类型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if isinstance(var_value, np.ndarray): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                result = var_value.tolist() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if type(fetch_vars[key]) is np.ndarray: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                result = res_dict[key][0] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             else: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 result = var_value 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            result_dict[var_name] = result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-         
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            result_dict[key] = result 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+             
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for key in fetch_vars: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if type(fetch_vars[key]) is np.ndarray: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                sys.stdout.write(f"{key}[0]: {fetch_vars[key][0]} ") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.current_batch.append(result_dict) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				          
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # # 当累积足够的结果时,写入文件 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -301,7 +304,6 @@ class Main(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				          
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # 创建处理器实例 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         fetch_handler = InferenceFetchHandler(var_dict = self.metrics, output_file = output_file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        # fetch_handler.set_var_dict(self.metrics) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         print(paddle.static.default_main_program()._fleet_opt) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.exe.infer_from_dataset( 
			 |