|
@@ -1,51 +1,64 @@
|
|
|
package com.tzld.piaoquan.recommend.model.service.model;
|
|
|
|
|
|
-
|
|
|
+import com.baidu.paddle.inference.Config;
|
|
|
+import com.baidu.paddle.inference.Predictor;
|
|
|
+import com.baidu.paddle.inference.Tensor;
|
|
|
+import com.tzld.piaoquan.recommend.feature.util.JSONUtils;
|
|
|
import com.tzld.piaoquan.recommend.model.util.CompressUtil;
|
|
|
import com.tzld.piaoquan.recommend.model.util.PropertiesUtil;
|
|
|
-import org.apache.commons.lang.math.NumberUtils;
|
|
|
-import org.slf4j.Logger;
|
|
|
-import org.slf4j.LoggerFactory;
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
|
|
|
import java.io.InputStream;
|
|
|
-import java.io.InputStreamReader;
|
|
|
-import java.util.Map;
|
|
|
|
|
|
+@Slf4j
|
|
|
+public class DSSMModel implements Model {
|
|
|
|
|
|
-public class DSSMModel extends Model {
|
|
|
- private static final Logger LOGGER = LoggerFactory.getLogger(DSSMModel.class);
|
|
|
-
|
|
|
-private Predictor predictor;
|
|
|
- @Override
|
|
|
- public boolean loadFromStream(InputStreamReader in) throws Exception {
|
|
|
- return false;
|
|
|
- }
|
|
|
+ private Predictor predictor;
|
|
|
|
|
|
public void cleanModel() {
|
|
|
- this.model = null;
|
|
|
+ this.predictor = null;
|
|
|
}
|
|
|
|
|
|
- public Float infer() {
|
|
|
-
|
|
|
- try {
|
|
|
- float[] values = new float[features.length];
|
|
|
- for (int i = 0; i < features.length; i++) {
|
|
|
- float v = NumberUtils.toFloat(featureMap.getOrDefault(features[i], "0.0"), 0.0f);
|
|
|
- values[i] = v;
|
|
|
- }
|
|
|
- DMatrix dm = new DMatrix(values, 1, features.length, 0.0f);
|
|
|
- float[][] result = model._booster().predict(dm, false, 100);
|
|
|
- return result[0][0];
|
|
|
- } catch (Exception e) {
|
|
|
- return 0f;
|
|
|
- }
|
|
|
+ public String predict(String param) {
|
|
|
+ // 1 获取输入Tensor
|
|
|
+ String inNames = predictor.getInputNameById(0);
|
|
|
+ Tensor inHandle = predictor.getInputHandle(inNames);
|
|
|
+
|
|
|
+ // 2 设置输入
|
|
|
+ inHandle.reshape(4, new int[]{1, 3, 224, 224});
|
|
|
+ float[] inData = new float[1 * 3 * 224 * 224];
|
|
|
+ inHandle.copyFromCpu(inData);
|
|
|
+
|
|
|
+ // 3 预测
|
|
|
+ predictor.run();
|
|
|
+
|
|
|
+ // 4 获取输入Tensor
|
|
|
+ String outNames = predictor.getOutputNameById(0);
|
|
|
+ Tensor outHandle = predictor.getOutputHandle(outNames);
|
|
|
+ float[] outData = new float[outHandle.getSize()];
|
|
|
+ outHandle.copyToCpu(outData);
|
|
|
+ return JSONUtils.toJson(outData);
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public boolean loadFromStream(InputStream in) throws Exception {
|
|
|
- String modelDir = PropertiesUtil.getString("model.xgboost.path");
|
|
|
+ String modelDir = PropertiesUtil.getString("model.dssm.path");
|
|
|
CompressUtil.decompressGzFile(in, modelDir);
|
|
|
- this.model = model2;
|
|
|
+
|
|
|
+ String modelFile = "";
|
|
|
+ String paramFile = "";
|
|
|
+
|
|
|
+ Config config = new Config();
|
|
|
+ config.setCppModel(modelFile, paramFile);
|
|
|
+ config.enableMemoryOptim(true);
|
|
|
+ config.enableProfile();
|
|
|
+ config.enableMKLDNN();
|
|
|
+ config.getCpuMathLibraryNumThreads();
|
|
|
+ config.getFractionOfGpuMemoryForPool();
|
|
|
+ config.switchIrDebug(false);
|
|
|
+
|
|
|
+ Predictor predictor2 = Predictor.createPaddlePredictor(config);
|
|
|
+ this.predictor = predictor2;
|
|
|
return true;
|
|
|
}
|
|
|
|