| 
					
				 | 
			
			
				@@ -220,28 +220,7 @@ class Main(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             self.train_result_dict["speed"].append(epoch_speed) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             model_dir = "{}/{}".format(save_model_path, epoch) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if paddle.distributed.get_rank() == 0: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                # 1. 确保所有 worker 同步 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                fleet.barrier_worker() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                # 2. 获取主程序 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                main_program = paddle.static.default_main_program() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                # 3. 使用 paddle.static.save_inference_model 替代 fleet.save_inference_model 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                paddle.static.save_inference_model( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    model_dir, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    [feed.name for feed in self.inference_feed_var], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    self.inference_target_var, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    self.exe, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    program=main_program,  # 使用 fleet 的主程序 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    export_for_deployment=True  # 保存为新格式 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                # 4. 再次同步确保保存完成 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                fleet.barrier_worker() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             if is_distributed_env(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                 fleet.save_inference_model( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     self.exe, model_dir, 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -253,7 +232,7 @@ class Main(object): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     model_dir, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     [feed.name for feed in self.inference_feed_var], 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				                     [self.inference_target_var], self.exe) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            """ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         if reader_type == "InmemoryDataset": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				             self.reader.release_memory() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |