|  | @@ -1,51 +1,64 @@
 | 
	
		
			
				|  |  |  package com.tzld.piaoquan.recommend.model.service.model;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | +import com.baidu.paddle.inference.Config;
 | 
	
		
			
				|  |  | +import com.baidu.paddle.inference.Predictor;
 | 
	
		
			
				|  |  | +import com.baidu.paddle.inference.Tensor;
 | 
	
		
			
				|  |  | +import com.tzld.piaoquan.recommend.feature.util.JSONUtils;
 | 
	
		
			
				|  |  |  import com.tzld.piaoquan.recommend.model.util.CompressUtil;
 | 
	
		
			
				|  |  |  import com.tzld.piaoquan.recommend.model.util.PropertiesUtil;
 | 
	
		
			
				|  |  | -import org.apache.commons.lang.math.NumberUtils;
 | 
	
		
			
				|  |  | -import org.slf4j.Logger;
 | 
	
		
			
				|  |  | -import org.slf4j.LoggerFactory;
 | 
	
		
			
				|  |  | +import lombok.extern.slf4j.Slf4j;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import java.io.InputStream;
 | 
	
		
			
				|  |  | -import java.io.InputStreamReader;
 | 
	
		
			
				|  |  | -import java.util.Map;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +@Slf4j
 | 
	
		
			
				|  |  | +public class DSSMModel implements Model {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -public class DSSMModel extends Model {
 | 
	
		
			
				|  |  | -    private static final Logger LOGGER = LoggerFactory.getLogger(DSSMModel.class);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -private Predictor predictor;
 | 
	
		
			
				|  |  | -    @Override
 | 
	
		
			
				|  |  | -    public boolean loadFromStream(InputStreamReader in) throws Exception {
 | 
	
		
			
				|  |  | -        return false;
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | +    private Predictor predictor;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      public void cleanModel() {
 | 
	
		
			
				|  |  | -        this.model = null;
 | 
	
		
			
				|  |  | +        this.predictor = null;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public Float infer() {
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        try {
 | 
	
		
			
				|  |  | -            float[] values = new float[features.length];
 | 
	
		
			
				|  |  | -            for (int i = 0; i < features.length; i++) {
 | 
	
		
			
				|  |  | -                float v = NumberUtils.toFloat(featureMap.getOrDefault(features[i], "0.0"), 0.0f);
 | 
	
		
			
				|  |  | -                values[i] = v;
 | 
	
		
			
				|  |  | -            }
 | 
	
		
			
				|  |  | -            DMatrix dm = new DMatrix(values, 1, features.length, 0.0f);
 | 
	
		
			
				|  |  | -            float[][] result = model._booster().predict(dm, false, 100);
 | 
	
		
			
				|  |  | -            return result[0][0];
 | 
	
		
			
				|  |  | -        } catch (Exception e) {
 | 
	
		
			
				|  |  | -            return 0f;
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | +    public String predict(String param) {
 | 
	
		
			
				|  |  | +        // 1 获取输入Tensor
 | 
	
		
			
				|  |  | +        String inNames = predictor.getInputNameById(0);
 | 
	
		
			
				|  |  | +        Tensor inHandle = predictor.getInputHandle(inNames);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // 2 设置输入
 | 
	
		
			
				|  |  | +        inHandle.reshape(4, new int[]{1, 3, 224, 224});
 | 
	
		
			
				|  |  | +        float[] inData = new float[1 * 3 * 224 * 224];
 | 
	
		
			
				|  |  | +        inHandle.copyFromCpu(inData);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // 3 预测
 | 
	
		
			
				|  |  | +        predictor.run();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // 4 获取输入Tensor
 | 
	
		
			
				|  |  | +        String outNames = predictor.getOutputNameById(0);
 | 
	
		
			
				|  |  | +        Tensor outHandle = predictor.getOutputHandle(outNames);
 | 
	
		
			
				|  |  | +        float[] outData = new float[outHandle.getSize()];
 | 
	
		
			
				|  |  | +        outHandle.copyToCpu(outData);
 | 
	
		
			
				|  |  | +        return JSONUtils.toJson(outData);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @Override
 | 
	
		
			
				|  |  |      public boolean loadFromStream(InputStream in) throws Exception {
 | 
	
		
			
				|  |  | -        String modelDir = PropertiesUtil.getString("model.xgboost.path");
 | 
	
		
			
				|  |  | +        String modelDir = PropertiesUtil.getString("model.dssm.path");
 | 
	
		
			
				|  |  |          CompressUtil.decompressGzFile(in, modelDir);
 | 
	
		
			
				|  |  | -        this.model = model2;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        String modelFile = "";
 | 
	
		
			
				|  |  | +        String paramFile = "";
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        Config config = new Config();
 | 
	
		
			
				|  |  | +        config.setCppModel(modelFile, paramFile);
 | 
	
		
			
				|  |  | +        config.enableMemoryOptim(true);
 | 
	
		
			
				|  |  | +        config.enableProfile();
 | 
	
		
			
				|  |  | +        config.enableMKLDNN();
 | 
	
		
			
				|  |  | +        config.getCpuMathLibraryNumThreads();
 | 
	
		
			
				|  |  | +        config.getFractionOfGpuMemoryForPool();
 | 
	
		
			
				|  |  | +        config.switchIrDebug(false);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        Predictor predictor2 = Predictor.createPaddlePredictor(config);
 | 
	
		
			
				|  |  | +        this.predictor = predictor2;
 | 
	
		
			
				|  |  |          return true;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 |