| 
					
				 | 
			
			
				@@ -76,66 +76,11 @@ def save_inference_model(model_path, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     model_path = os.path.join(model_path, str(epoch_id)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     _mkdir_if_not_exist(model_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     model_prefix = os.path.join(model_path, prefix) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    fetch_vars2 = fetch_vars 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    program = _get_valid_program(None) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # fix the bug that the activation op's output as target will be pruned. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # will affect the inference performance. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    # TODO(Superjomn) add an IR pass to remove 1-scale op. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    with program_guard(program): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        uniq_fetch_vars = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        for i, var in enumerate(fetch_vars): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if var.dtype != paddle.bool: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                var = paddle.scale(var, 1.0, name=f"save_infer_model/scale_{i}") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            uniq_fetch_vars.append(var) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        fetch_vars = uniq_fetch_vars 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    copy_program = program.clone() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    logger.info("program.clone(): {}".format(copy_program.global_block().vars.keys())) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    global_block = copy_program.global_block() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    remove_op_idx = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    for i, op in enumerate(global_block.ops): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        op.desc.set_is_target(False) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if op.type == "feed" or op.type == "fetch": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            remove_op_idx.append(i) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        if op.type == "pylayer": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            sub_blocks_ids = op._blocks_attr_ids("blocks") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            if len(sub_blocks_ids) > 1: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                # pylayer op ``blocks`` attr contains forward block id and backward block id 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                backward_block_id = sub_blocks_ids[-1] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                # remove backward block 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                copy_program.blocks.pop(backward_block_id) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                # update attrs ``blocks`` 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                reserverd_blocks = [] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                for block_id in sub_blocks_ids[:-1]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                    reserverd_blocks.append(copy_program.block(block_id)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-                op._update_desc_attr("blocks", reserverd_blocks) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    for idx in remove_op_idx[::-1]: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        global_block._remove_op(idx) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    copy_program.desc.flush() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    logger.info("copy_program.desc.flush(): {}".format(copy_program.global_block().vars.keys())) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    feed_var_names = [var.name for var in feed_vars] 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    skip_prune_program = False 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    if not skip_prune_program: 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        copy_program = copy_program._prune_with_input( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-            feeded_var_names=feed_var_names, targets=fetch_vars 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        ) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        logger.info("copy_program._prune_with_input(): {}".format(copy_program.global_block().vars.keys())) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    copy_program = copy_program._inference_optimize(prune_read_op=True) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-    logger.info("copy_program._inference_optimize(): {}".format(copy_program.global_block().vars.keys())) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-     
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     paddle.static.save_inference_model( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         path_prefix=model_prefix, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         feed_vars=feed_vars, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-        fetch_vars=fetch_vars2, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        fetch_vars=fetch_vars, 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				         executor=exe) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     return model_prefix 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 |