Kaynağa Gözat

model service

丁云鹏 5 ay önce
ebeveyn
işleme
beb6169180

+ 21 - 9
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/DSSMModel.java

@@ -13,14 +13,16 @@ import java.io.InputStream;
 @Slf4j
 public class DSSMModel implements Model {
 
-    private Predictor predictor;
+    private Predictor sourcePredictor;
 
     public void cleanModel() {
-        this.predictor = null;
+        this.sourcePredictor = null;
     }
 
     public String predict(String param) {
+        long time1 = System.currentTimeMillis();
         // 1. Fetch the input tensor
+        Predictor predictor = Predictor.clonePaddlePredictor(sourcePredictor);
         String inNames = predictor.getInputNameById(0);
         Tensor inHandle = predictor.getInputHandle(inNames);
 
@@ -37,28 +39,38 @@ public class DSSMModel implements Model {
         Tensor outHandle = predictor.getOutputHandle(outNames);
         float[] outData = new float[outHandle.getSize()];
         outHandle.copyToCpu(outData);
+
+        long time2 = System.currentTimeMillis();
+        log.info("predictor2 outData[0]={},outDataLen={},cost={}", outData[0], outData.length, (time2 - time1));
+
+        outHandle.destroyNativeTensor();
+        inHandle.destroyNativeTensor();
+        predictor.destroyNativePredictor();
+
         return JSONUtils.toJson(outData);
     }
 
     @Override
     public boolean loadFromStream(InputStream in) throws Exception {
-        String modelDir = PropertiesUtil.getString("model.dir");
-        CompressUtil.decompressGzFile(in, modelDir + "/dssm");
+        String modelDir = PropertiesUtil.getString("model.dir") + "/demo";
+        CompressUtil.decompressGzFile(in, modelDir);
 
-        String modelFile = "";
-        String paramFile = "";
+        String modelFile = modelDir + "/inference.pdmodel";
+        String paramFile = modelDir + "/inference.pdiparams";
 
         Config config = new Config();
         config.setCppModel(modelFile, paramFile);
         config.enableMemoryOptim(true);
-        config.enableProfile();
         config.enableMKLDNN();
         config.getCpuMathLibraryNumThreads();
         config.getFractionOfGpuMemoryForPool();
         config.switchIrDebug(false);
 
-        Predictor predictor2 = Predictor.createPaddlePredictor(config);
-        this.predictor = predictor2;
+        Predictor temp = sourcePredictor;
+        sourcePredictor = Predictor.createPaddlePredictor(config);
+        if (temp != null) {
+            temp.destroyNativePredictor();
+        }
         return true;
     }