| 
					
				 | 
			
			
				@@ -5,7 +5,9 @@ from utils.oss_client import HangZhouOSSClient 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import utils.compress as compress 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 from utils.my_hdfs_client import MyHDFSClient 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import logging 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-import paddle 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+# 引用 paddle inference 推理库 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import paddle.inference as paddle_infer 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 logging.basicConfig( 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				     format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -19,3 +21,44 @@ configs = { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 hdfs_client = MyHDFSClient(hadoop_home, configs) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+def main(): 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    init_model_path = "/app/output_model_dssm" 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    client = HangZhouOSSClient("art-recommend") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    oss_object_name = self.config.get("runner.oss_object_name", "dyp/model.tar.gz") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    client.get_object_to_file(oss_object_name, "model.tar.gz") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    compress.uncompress_tar("model.tar.gz", init_model_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    assert os.path.exists(init_model_path) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    model_file=os.path.join(init_model_path, "dssm.pdmodel") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    params_file=os.path.join(init_model_path, "dssm.pdiparams") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 创建 config 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    config = paddle_infer.Config(model_file, params_file) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 根据 config 创建 predictor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    predictor = paddle_infer.create_predictor(config) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 获取输入的名称 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    input_names = predictor.get_input_names() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    input_handle = predictor.get_input_handle(input_names[0]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 设置输入 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    fake_input = np.random.randn(1, 157).astype("float32") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    input_handle.reshape([1, 157]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    input_handle.copy_from_cpu(fake_input) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 运行predictor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    predictor.run() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    # 获取输出 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    output_names = predictor.get_output_names() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    output_handle = predictor.get_output_handle(output_names[0]) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    output_data = output_handle.copy_to_cpu() # numpy.ndarray类型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print("Output data size is {}".format(output_data.size)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    print("Output data shape is {}".format(output_data.shape)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+if __name__ == "__main__": 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    main() 
			 |