Jelajahi Sumber

model service

丁云鹏 5 bulan lalu
induk
melakukan
a62826713b

+ 9 - 2
recommend-model-jni/src/main/c/com/baidu/paddle/inference/com_baidu_paddle_inference_Predictor.cpp

@@ -42,9 +42,16 @@ Java_com_baidu_paddle_inference_Predictor_predictorClearIntermediateTensor(
 
 JNIEXPORT jlong JNICALL
 Java_com_baidu_paddle_inference_Predictor_createPredictor(
-    JNIEnv* env, jobject obj, jlong cppPaddlePredictorPointer) {
+    JNIEnv* env, jobject obj, jlong cppPaddleConfigPointer) {
   return (jlong)PD_PredictorCreate(
-      reinterpret_cast<PD_Config*>(cppPaddlePredictorPointer));
+      reinterpret_cast<PD_Config*>(cppPaddleConfigPointer));
+}
+
+JNIEXPORT jlong JNICALL
+Java_com_baidu_paddle_inference_Predictor_clonePredictor(
+    JNIEnv* env, jobject obj, jlong cppPaddlePredictorPointer){
+  return (jlong)PD_PredictorClone(
+        reinterpret_cast<PD_Config*>(cppPaddlePredictorPointer));
 }
 
 JNIEXPORT jlong JNICALL Java_com_baidu_paddle_inference_Predictor_getInputNum(

+ 8 - 0
recommend-model-jni/src/main/c/com/baidu/paddle/inference/com_baidu_paddle_inference_Predictor.h

@@ -59,6 +59,14 @@ Java_com_baidu_paddle_inference_Predictor_createPredictor(JNIEnv *,
                                                           jobject,
                                                           jlong);
 
+/*
+ * Class:     com_baidu_paddle_inference_Predictor
+ * Method:    clonePredictor
+ * Signature: (J)J
+ */
+JNIEXPORT jlong JNICALL Java_com_baidu_paddle_inference_Predictor_clonePredictor
+  (JNIEnv *, jobject, jlong);
+
 /*
  * Class:     com_baidu_paddle_inference_Predictor
  * Method:    getInputNum

+ 28 - 8
recommend-model-jni/src/test/java/com/baidu/paddle/inference/test.java

@@ -27,15 +27,9 @@ public class test {
         System.out.println(config.summary());
 
         Predictor predictor = Predictor.createPaddlePredictor(config);
-
-        long n = predictor.getInputNum();
-
         String inNames = predictor.getInputNameById(0);
-
         Tensor inHandle = predictor.getInputHandle(inNames);
-
         inHandle.reshape(4, new int[]{1, 3, 224, 224});
-
         float[] inData = new float[1 * 3 * 224 * 224];
         inHandle.copyFromCpu(inData);
         predictor.run();
@@ -47,8 +41,10 @@ public class test {
         predictor.tryShrinkMemory();
         predictor.clearIntermediateTensor();
 
-        System.out.println(outData[0]);
-        System.out.println(outData.length);
+        System.out.println("predictor1: " + outData[0]);
+        System.out.println("predictor1: " + outData.length);
+
+        test(predictor);
 
         outHandle.destroyNativeTensor();
         inHandle.destroyNativeTensor();
@@ -63,5 +59,29 @@ public class test {
         System.out.println("params file:\n" + newConfig.getCppParamsFile());
         config.destroyNativeConfig();
 
+
+    }
+
+    private static void test(Predictor predictor) {
+        Predictor predictor2 = Predictor.clonePaddlePredictor(predictor);
+        String inNames = predictor.getInputNameById(0);
+        Tensor inHandle = predictor.getInputHandle(inNames);
+        inHandle.reshape(4, new int[]{1, 3, 224, 224});
+        float[] inData = new float[1 * 3 * 224 * 224];
+        inHandle.copyFromCpu(inData);
+        predictor.run();
+        String outNames = predictor.getOutputNameById(0);
+        Tensor outHandle = predictor.getOutputHandle(outNames);
+        float[] outData = new float[outHandle.getSize()];
+        outHandle.copyToCpu(outData);
+
+        predictor.tryShrinkMemory();
+        predictor.clearIntermediateTensor();
+
+        System.out.println("predictor2: " + outData[0]);
+        System.out.println("predictor2: " + outData.length);
+        outHandle.destroyNativeTensor();
+        inHandle.destroyNativeTensor();
+        predictor.destroyNativePredictor();
     }
 }

+ 2 - 2
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/DSSMModel.java

@@ -42,8 +42,8 @@ public class DSSMModel implements Model {
 
     @Override
     public boolean loadFromStream(InputStream in) throws Exception {
-        String modelDir = PropertiesUtil.getString("model.dssm.path");
-        CompressUtil.decompressGzFile(in, modelDir);
+        String modelDir = PropertiesUtil.getString("model.dir");
+        CompressUtil.decompressGzFile(in, modelDir + "/dssm");
 
         String modelFile = "";
         String paramFile = "";

+ 71 - 0
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/DemoModel.java

@@ -0,0 +1,71 @@
+package com.tzld.piaoquan.recommend.model.service.model;
+
+import com.baidu.paddle.inference.Config;
+import com.baidu.paddle.inference.Predictor;
+import com.baidu.paddle.inference.Tensor;
+import com.tzld.piaoquan.recommend.feature.util.JSONUtils;
+import com.tzld.piaoquan.recommend.model.util.CompressUtil;
+import com.tzld.piaoquan.recommend.model.util.PropertiesUtil;
+import lombok.extern.slf4j.Slf4j;
+
+import java.io.InputStream;
+
+@Slf4j
+public class DemoModel implements Model {
+
+    private Predictor sourcePredictor;
+
+    public void cleanModel() {
+        this.sourcePredictor.destroyNativePredictor();
+        this.sourcePredictor = null;
+    }
+
+    public String predict(String param) {
+        Predictor predictor = Predictor.clonePaddlePredictor(sourcePredictor);
+
+        String inNames = predictor.getInputNameById(0);
+        Tensor inHandle = predictor.getInputHandle(inNames);
+
+        inHandle.reshape(4, new int[]{1, 3, 224, 224});
+
+        float[] inData = new float[1 * 3 * 224 * 224];
+        inHandle.copyFromCpu(inData);
+        predictor.run();
+        String outNames = predictor.getOutputNameById(0);
+        Tensor outHandle = predictor.getOutputHandle(outNames);
+        float[] outData = new float[outHandle.getSize()];
+        outHandle.copyToCpu(outData);
+
+
+        outHandle.destroyNativeTensor();
+        inHandle.destroyNativeTensor();
+        predictor.destroyNativePredictor();
+
+        return JSONUtils.toJson(outData);
+    }
+
+    @Override
+    public boolean loadFromStream(InputStream in) throws Exception {
+        String modelDir = PropertiesUtil.getString("model.dir");
+        CompressUtil.decompressGzFile(in, modelDir + "/demo");
+
+        String modelFile = "";
+        String paramFile = "";
+
+        Config config = new Config();
+        config.setCppModel(modelFile, paramFile);
+        config.enableMemoryOptim(true);
+        config.enableProfile();
+        config.enableMKLDNN();
+        config.getCpuMathLibraryNumThreads();
+        config.getFractionOfGpuMemoryForPool();
+        config.switchIrDebug(false);
+
+        Predictor predictor = Predictor.createPaddlePredictor(config);
+        Predictor temp = predictor;
+        this.sourcePredictor = predictor;
+        temp.destroyNativePredictor();
+        return true;
+    }
+
+}

+ 1 - 0
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/ModelEnum.java

@@ -4,6 +4,7 @@ import org.apache.commons.lang3.StringUtils;
 
 public enum ModelEnum {
     VIDEO_DSSM("videoDssm", "", DSSMModel.class),
+    DEMO("demo", "", DemoModel.class),
     NULL("null", "null", null);
 
     private String modelName;

+ 4 - 0
recommend-model-service/src/main/resources/application.yml

@@ -14,3 +14,7 @@ apollo:
     enabled: true
     namespaces: application
   cacheDir: /datalog/apollo-cache-dir
+
+model:
+  dssm:
+    path: model