| 
					
				 | 
			
			
				@@ -50,13 +50,8 @@ logger = logging.getLogger(__name__) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import json 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-class InferenceFetchHandler(FetchHandler): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    def __init__(self, var_dict, period_secs, output_file, batch_size=1000): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        assert var_dict is not None 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.var_dict = var_dict 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        self.period_secs = period_secs 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+class InferenceFetchHandler(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    def __init__(self, output_file, batch_size=1000): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.output_file = output_file 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.batch_size = batch_size 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.current_batch = [] 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -322,6 +317,7 @@ class Main(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				          
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         # 创建处理器实例 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         fetch_handler = InferenceFetchHandler(output_file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fetch_handler.set_var_dict(fetch_vars) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         print(paddle.static.default_main_program()._fleet_opt) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         self.exe.infer_from_dataset( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             program=paddle.static.default_main_program(), 
			 |