| 
														
															@@ -18,6 +18,7 @@ import logging 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 import numpy as np 
														 | 
														
														 | 
														
															 import numpy as np 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 from paddle.static.io import _get_valid_program, normalize_program, program_guard 
														 | 
														
														 | 
														
															 from paddle.static.io import _get_valid_program, normalize_program, program_guard 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 logging.basicConfig( 
														 | 
														
														 | 
														
															 logging.basicConfig( 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) 
														 | 
														
														 | 
														
															     format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 logger = logging.getLogger(__name__) 
														 | 
														
														 | 
														
															 logger = logging.getLogger(__name__) 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -79,16 +80,6 @@ def save_inference_model(model_path, 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     program = _get_valid_program(None) 
														 | 
														
														 | 
														
															     program = _get_valid_program(None) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															-    for op in program.global_block().ops: 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        # clear device of Op 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        op._set_attr(device_attr_name, "") 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        if op.type == 'auc': 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            warnings.warn( 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-                "Be sure that you have set auc states to 0 before saving inference model." 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            ) 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-            break 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															- 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     # fix the bug that the activation op's output as target will be pruned. 
														 | 
														
														 | 
														
															     # fix the bug that the activation op's output as target will be pruned. 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     # will affect the inference performance. 
														 | 
														
														 | 
														
															     # will affect the inference performance. 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     # TODO(Superjomn) add an IR pass to remove 1-scale op. 
														 | 
														
														 | 
														
															     # TODO(Superjomn) add an IR pass to remove 1-scale op. 
														 |