丁云鹏 пре 6 месеци
родитељ
комит
4c38c745f9
1 измењених фајлова са 44 додато и 1 уклоњено
  1. 44 1
      recommend-model-produce/src/main/python/tools/inferv2.py

+ 44 - 1
recommend-model-produce/src/main/python/tools/inferv2.py

@@ -5,7 +5,9 @@ from utils.oss_client import HangZhouOSSClient
 import utils.compress as compress
 from utils.my_hdfs_client import MyHDFSClient
 import logging
-import paddle
+# 引用 paddle inference 推理库
+import paddle.inference as paddle_infer
+
 
 logging.basicConfig(
     format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
@@ -19,3 +21,44 @@ configs = {
 hdfs_client = MyHDFSClient(hadoop_home, configs)
 
 
+
+
+def main():
+
+    init_model_path = "/app/output_model_dssm"
+    client = HangZhouOSSClient("art-recommend")
+    oss_object_name = self.config.get("runner.oss_object_name", "dyp/model.tar.gz")
+    client.get_object_to_file(oss_object_name, "model.tar.gz")
+    compress.uncompress_tar("model.tar.gz", init_model_path)
+    assert os.path.exists(init_model_path)
+
+    model_file=os.path.join(init_model_path, "dssm.pdmodel")
+    params_file=os.path.join(init_model_path, "dssm.pdiparams")
+
+    # 创建 config
+    config = paddle_infer.Config(model_file, params_file)
+
+    # 根据 config 创建 predictor
+    predictor = paddle_infer.create_predictor(config)
+
+    # 获取输入的名称
+    input_names = predictor.get_input_names()
+    input_handle = predictor.get_input_handle(input_names[0])
+
+    # 设置输入
+    fake_input = np.random.randn(1, 157).astype("float32")
+    input_handle.reshape([1, 157])
+    input_handle.copy_from_cpu(fake_input)
+
+    # 运行predictor
+    predictor.run()
+
+    # 获取输出
+    output_names = predictor.get_output_names()
+    output_handle = predictor.get_output_handle(output_names[0])
+    output_data = output_handle.copy_to_cpu() # numpy.ndarray类型
+    print("Output data size is {}".format(output_data.size))
+    print("Output data shape is {}".format(output_data.shape))
+
+if __name__ == "__main__":
+    main()