Browse Source

dssm train

丁云鹏 4 months ago
parent
commit
4fb70486af

+ 19 - 7
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/DemoModel.java

@@ -17,25 +17,37 @@ public class DemoModel implements Model {
     private Predictor sourcePredictor;
 
     public void cleanModel() {
-        this.sourcePredictor.destroyNativePredictor();
         this.sourcePredictor = null;
     }
 
     public String predict(String param) {
-
         long time1 = System.currentTimeMillis();
+        // 1 获取输入Tensor
         Predictor predictor = Predictor.clonePaddlePredictor(sourcePredictor);
         String inNames = predictor.getInputNameById(0);
         Tensor inHandle = predictor.getInputHandle(inNames);
-        inHandle.reshape(4, new int[]{1, 3, 224, 224});
 
-        float[] inData = new float[1 * 3 * 224 * 224];
+        for (int i = 0; i < predictor.getInputNum(); i++) {
+            log.info("predictor2 inName{}={}", i, predictor.getInputNameById(i));
+        }
+
+        for (int i = 0; i < predictor.getOutputNum(); i++) {
+            log.info("predictor2 outName{}={}", i, predictor.getOutputNameById(i));
+        }
+        // 2 设置输入
+        inHandle.reshape(2, new int[]{1, 157});
+        float[] inData = new float[1 * 157];
         inHandle.copyFromCpu(inData);
+
+        // 3 预测
         predictor.run();
+
+        // 4 获取输入Tensor
         String outNames = predictor.getOutputNameById(0);
         Tensor outHandle = predictor.getOutputHandle(outNames);
         float[] outData = new float[outHandle.getSize()];
         outHandle.copyToCpu(outData);
+
         long time2 = System.currentTimeMillis();
         log.info("predictor2 outData[0]={},outDataLen={},cost={}", outData[0], outData.length, (time2 - time1));
 
@@ -48,11 +60,11 @@ public class DemoModel implements Model {
 
     @Override
     public boolean loadFromStream(InputStream in) throws Exception {
-        String modelDir = PropertiesUtil.getString("model.dir") + "/demo";
+        String modelDir = PropertiesUtil.getString("model.dir") + "/dssm";
         CompressUtil.decompressGzFile(in, modelDir);
 
-        String modelFile = modelDir + "/inference.pdmodel";
-        String paramFile = modelDir + "/inference.pdiparams";
+        String modelFile = modelDir + "/dssm.pdmodel";
+        String paramFile = modelDir + "/dssm.pdiparams";
 
         Config config = new Config();
         config.setCppModel(modelFile, paramFile);

+ 2 - 3
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/ModelEnum.java

@@ -3,9 +3,8 @@ package com.tzld.piaoquan.recommend.model.service.model;
 import org.apache.commons.lang3.StringUtils;
 
 public enum ModelEnum {
-    VIDEO_DSSM("videoDssm", "dyp/dssm_demo.tar.gz", DSSMModel.class),
-    DEMO("demo", "zhangbo/model_paddle_demo.tar.gz", DemoModel.class),
-    DNN("dnn", "dyp/dnn.tar.gz", DNNModel.class),
+    VIDEO_DSSM("videoDssm", "dyp/dssm.tar.gz", DSSMModel.class),
+    DEMO("demo", "dyp/dssm_demo.tar.gz", DemoModel.class),
     NULL("null", "null", null);
 
     private String modelName;

+ 0 - 7
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/web/DemoController.java

@@ -22,11 +22,4 @@ public class DemoController {
         String date = model.predict("");
         return date;
     }
-
-    @GetMapping("/dssm")
-    public String dssm() {
-        Model model = ModelManager.getInstance().getModel(ModelEnum.VIDEO_DSSM);
-        String date = model.predict("");
-        return date;
-    }
 }