| 
					
				 | 
			
			
				@@ -16,7 +16,7 @@ import paddle 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import os 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import logging 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import numpy as np 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-from paddle.static.io import _get_valid_program, normalize_program 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+from paddle.static.io import _get_valid_program, normalize_program, program_guard 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 logging.basicConfig( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -75,9 +75,31 @@ def save_inference_model(model_path, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     model_path = os.path.join(model_path, str(epoch_id)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     _mkdir_if_not_exist(model_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     model_prefix = os.path.join(model_path, prefix) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    fetch_vars2 = fetch_vars 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     program = _get_valid_program(None) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    for op in program.global_block().ops: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        # clear device of Op 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        op._set_attr(device_attr_name, "") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        if op.type == 'auc': 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            warnings.warn( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                "Be sure that you have set auc states to 0 before saving inference model." 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            break 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # fix the bug that the activation op's output as target will be pruned. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # will affect the inference performance. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # TODO(Superjomn) add an IR pass to remove 1-scale op. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    with program_guard(program): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        uniq_fetch_vars = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        for i, var in enumerate(fetch_vars): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            if var.dtype != paddle.bool: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                var = paddle.scale(var, 1.0, name=f"save_infer_model/scale_{i}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            uniq_fetch_vars.append(var) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fetch_vars = uniq_fetch_vars 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     copy_program = program.clone() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     logger.info("program.clone(): {}".format(copy_program.global_block().vars.keys())) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -122,7 +144,7 @@ def save_inference_model(model_path, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     paddle.static.save_inference_model( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         path_prefix=model_prefix, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         feed_vars=feed_vars, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        fetch_vars=fetch_vars, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fetch_vars=fetch_vars2, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         executor=exe) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return model_prefix 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |