@@ -5,7 +5,9 @@ from utils.oss_client import HangZhouOSSClient
 import utils.compress as compress
 from utils.my_hdfs_client import MyHDFSClient
 import logging
-import paddle
+# Import the paddle inference library
+import paddle.inference as paddle_infer
+
 
 logging.basicConfig(
     format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
@@ -19,3 +21,44 @@ configs = {
 hdfs_client = MyHDFSClient(hadoop_home, configs)
+
+
+def main():
+    # Download the packaged DSSM model from OSS and unpack it locally
+    init_model_path = "/app/output_model_dssm"
+    client = HangZhouOSSClient("art-recommend")
+    oss_object_name = "dyp/model.tar.gz"  # OSS key of the packaged model; could come from a runner.oss_object_name config entry
+    client.get_object_to_file(oss_object_name, "model.tar.gz")
+    compress.uncompress_tar("model.tar.gz", init_model_path)
+    assert os.path.exists(init_model_path)
+
+    model_file = os.path.join(init_model_path, "dssm.pdmodel")
+    params_file = os.path.join(init_model_path, "dssm.pdiparams")
+
+    # Create the config
+    config = paddle_infer.Config(model_file, params_file)
+
+    # Create the predictor from the config
+    predictor = paddle_infer.create_predictor(config)
+
+    # Get the input names
+    input_names = predictor.get_input_names()
+    input_handle = predictor.get_input_handle(input_names[0])
+
+    # Set the input
+    fake_input = np.random.randn(1, 157).astype("float32")
+    input_handle.reshape([1, 157])
+    input_handle.copy_from_cpu(fake_input)
+
+    # Run the predictor
+    predictor.run()
+
+    # Get the output
+    output_names = predictor.get_output_names()
+    output_handle = predictor.get_output_handle(output_names[0])
+    output_data = output_handle.copy_to_cpu()  # numpy.ndarray
+    print("Output data size is {}".format(output_data.size))
+    print("Output data shape is {}".format(output_data.shape))
+
+if __name__ == "__main__":
+    main()
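
Not part of the patch above, but as a rough sketch of how the inference config could be tuned before the predictor is created (the thread count and GPU memory pool size below are illustrative assumptions, not values from the patch):

# Optional tuning of the inference config (illustrative values)
config = paddle_infer.Config(model_file, params_file)
config.disable_gpu()                        # run on CPU
config.set_cpu_math_library_num_threads(4)  # CPU math library threads
config.enable_memory_optim()                # reuse intermediate tensors to reduce memory
# config.enable_use_gpu(100, 0)             # alternatively: 100 MB initial pool on GPU card 0
predictor = paddle_infer.create_predictor(config)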