Bläddra i källkod

model service

丁云鹏 5 månader sedan
förälder
incheckning
0264ac593e

+ 10 - 0
recommend-model-service/pom.xml

@@ -73,6 +73,11 @@
             <artifactId>recommend-model-client</artifactId>
             <version>1.0.0</version>
         </dependency>
+        <dependency>
+            <groupId>com.tzld.piaoquan</groupId>
+            <artifactId>recommend-model-jni</artifactId>
+            <version>1.0.0</version>
+        </dependency>
         <dependency>
             <groupId>com.google.protobuf</groupId>
             <artifactId>protobuf-java</artifactId>
@@ -109,6 +114,11 @@
             <artifactId>aliyun-sdk-oss</artifactId>
             <version>3.15.1</version>
         </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-compress</artifactId>
+            <version>1.21</version>
+        </dependency>
     </dependencies>
     <build>
         <finalName>recommend-model-service</finalName>

+ 25 - 0
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/grpcservice/GrpcAspect.java

@@ -0,0 +1,25 @@
+package com.tzld.piaoquan.recommend.model.grpcservice;
+
+import com.tzld.piaoquan.recommend.feature.util.TraceUtils;
+import lombok.extern.slf4j.Slf4j;
+import org.aspectj.lang.ProceedingJoinPoint;
+import org.aspectj.lang.annotation.Around;
+import org.aspectj.lang.annotation.Aspect;
+import org.springframework.stereotype.Component;
+
+/**
+ * @author dyp
+ */
+@Aspect
+@Component
+@Slf4j
+public class GrpcAspect {
+
+    @Around("execution(* com.tzld.piaoquan.recommend.model.grpcservice.*GrpcService.*(..))")
+    public Object around(ProceedingJoinPoint pjp) throws Throwable {
+        TraceUtils.setMDC();
+        Object result = pjp.proceed();
+        TraceUtils.removeMDC();
+        return result;
+    }
+}

+ 33 - 0
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/grpcservice/PredictGrpcService.java

@@ -0,0 +1,33 @@
+package com.tzld.piaoquan.recommend.model.grpcservice;
+
+import com.tzld.piaoquan.recommend.model.grpc.model.PredictRequest;
+import com.tzld.piaoquan.recommend.model.grpc.model.PredictResponse;
+import com.tzld.piaoquan.recommend.model.grpc.model.PredictServiceGrpc;
+import com.tzld.piaoquan.recommend.model.service.PredictService;
+import io.grpc.stub.StreamObserver;
+import lombok.extern.slf4j.Slf4j;
+import net.devh.boot.grpc.server.service.GrpcService;
+import org.springframework.beans.factory.annotation.Autowired;
+
+/**
+ * @author dyp
+ */
+@GrpcService
+@Slf4j
+public class PredictGrpcService extends PredictServiceGrpc.PredictServiceImplBase {
+
+    @Autowired
+    private PredictService predictService;
+
+    @Override
+    public void predict(PredictRequest request, StreamObserver<PredictResponse> responseObserver) {
+
+        //log.info("PredictGrpcService predict request={}", ProtobufUtils.toJson(request));
+        PredictResponse response = predictService.predict(request);
+        //log.info("PredictGrpcService predict response={}", ProtobufUtils.toJson(response));
+
+        responseObserver.onNext(response);
+        responseObserver.onCompleted();
+    }
+
+}

+ 40 - 0
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/PredictService.java

@@ -0,0 +1,40 @@
+package com.tzld.piaoquan.recommend.model.service;
+
+import com.google.common.base.Strings;
+import com.tzld.piaoquan.recommend.model.grpc.model.PredictRequest;
+import com.tzld.piaoquan.recommend.model.grpc.model.PredictResponse;
+import com.tzld.piaoquan.recommend.model.grpc.model.common.Result;
+import com.tzld.piaoquan.recommend.model.service.model.Model;
+import com.tzld.piaoquan.recommend.model.service.model.ModelEnum;
+import com.tzld.piaoquan.recommend.model.service.model.ModelManager;
+
+/**
+ * @author dyp
+ */
+public class PredictService {
+    public PredictResponse predict(PredictRequest request) {
+        String modelName = request.getModelName();
+
+        ModelEnum modelEnum = ModelEnum.which(modelName);
+
+        Model model = ModelManager.getInstance().getModel(modelEnum);
+        if (model == null) {
+            return PredictResponse.newBuilder()
+                    .setResult(Result.newBuilder()
+                            .setCode(2)
+                            .setMessage("Model [" + modelName + "] not support!"))
+                    .build();
+        }
+
+        String data = model.predict(request.getParam());
+
+        return PredictResponse.newBuilder()
+                .setResult(Result.newBuilder()
+                        .setCode(1)
+                        .setMessage("success"))
+                .setData(Strings.nullToEmpty(data))
+                .build();
+
+
+    }
+}

+ 5 - 1
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/WarmUpService.java

@@ -1,5 +1,8 @@
 package com.tzld.piaoquan.recommend.model.service;
 
+import com.tzld.piaoquan.recommend.model.service.model.ModelEnum;
+import com.tzld.piaoquan.recommend.model.service.model.ModelManager;
+import lombok.extern.slf4j.Slf4j;
 import org.springframework.stereotype.Component;
 
 import javax.annotation.PostConstruct;
@@ -8,9 +11,10 @@ import javax.annotation.PostConstruct;
  * @author dyp
  */
 @Component
+@Slf4j
 public class WarmUpService {
     @PostConstruct
     public void warmup() {
-
+        ModelManager.getInstance().registerModel(ModelEnum.VIDEO_DSSM);
     }
 }

+ 44 - 31
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/DSSMModel.java

@@ -1,51 +1,64 @@
 package com.tzld.piaoquan.recommend.model.service.model;
 
-
+import com.baidu.paddle.inference.Config;
+import com.baidu.paddle.inference.Predictor;
+import com.baidu.paddle.inference.Tensor;
+import com.tzld.piaoquan.recommend.feature.util.JSONUtils;
 import com.tzld.piaoquan.recommend.model.util.CompressUtil;
 import com.tzld.piaoquan.recommend.model.util.PropertiesUtil;
-import org.apache.commons.lang.math.NumberUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import lombok.extern.slf4j.Slf4j;
 
 import java.io.InputStream;
-import java.io.InputStreamReader;
-import java.util.Map;
 
+@Slf4j
+public class DSSMModel implements Model {
 
-public class DSSMModel extends Model {
-    private static final Logger LOGGER = LoggerFactory.getLogger(DSSMModel.class);
-
-private Predictor predictor;
-    @Override
-    public boolean loadFromStream(InputStreamReader in) throws Exception {
-        return false;
-    }
+    private Predictor predictor;
 
     public void cleanModel() {
-        this.model = null;
+        this.predictor = null;
     }
 
-    public Float infer() {
-
-        try {
-            float[] values = new float[features.length];
-            for (int i = 0; i < features.length; i++) {
-                float v = NumberUtils.toFloat(featureMap.getOrDefault(features[i], "0.0"), 0.0f);
-                values[i] = v;
-            }
-            DMatrix dm = new DMatrix(values, 1, features.length, 0.0f);
-            float[][] result = model._booster().predict(dm, false, 100);
-            return result[0][0];
-        } catch (Exception e) {
-            return 0f;
-        }
+    public String predict(String param) {
+        // 1 获取输入Tensor
+        String inNames = predictor.getInputNameById(0);
+        Tensor inHandle = predictor.getInputHandle(inNames);
+
+        // 2 设置输入
+        inHandle.reshape(4, new int[]{1, 3, 224, 224});
+        float[] inData = new float[1 * 3 * 224 * 224];
+        inHandle.copyFromCpu(inData);
+
+        // 3 预测
+        predictor.run();
+
+        // 4 获取输入Tensor
+        String outNames = predictor.getOutputNameById(0);
+        Tensor outHandle = predictor.getOutputHandle(outNames);
+        float[] outData = new float[outHandle.getSize()];
+        outHandle.copyToCpu(outData);
+        return JSONUtils.toJson(outData);
     }
 
     @Override
     public boolean loadFromStream(InputStream in) throws Exception {
-        String modelDir = PropertiesUtil.getString("model.xgboost.path");
+        String modelDir = PropertiesUtil.getString("model.dssm.path");
         CompressUtil.decompressGzFile(in, modelDir);
-        this.model = model2;
+
+        String modelFile = "";
+        String paramFile = "";
+
+        Config config = new Config();
+        config.setCppModel(modelFile, paramFile);
+        config.enableMemoryOptim(true);
+        config.enableProfile();
+        config.enableMKLDNN();
+        config.getCpuMathLibraryNumThreads();
+        config.getFractionOfGpuMemoryForPool();
+        config.switchIrDebug(false);
+
+        Predictor predictor2 = Predictor.createPaddlePredictor(config);
+        this.predictor = predictor2;
         return true;
     }
 

+ 4 - 5
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/Model.java

@@ -4,11 +4,10 @@ package com.tzld.piaoquan.recommend.model.service.model;
 import java.io.InputStream;
 import java.io.InputStreamReader;
 
-abstract public class Model {
+public interface Model {
 
-    public abstract boolean loadFromStream(InputStreamReader in) throws Exception;
-    public boolean loadFromStream(InputStream is) throws Exception {
-        return loadFromStream(new InputStreamReader(is));
-    }
+    boolean loadFromStream(InputStream in) throws Exception;
+
+    String predict(String input);
 }
 

+ 39 - 0
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/ModelEnum.java

@@ -0,0 +1,39 @@
+package com.tzld.piaoquan.recommend.model.service.model;
+
+import org.apache.commons.lang3.StringUtils;
+
+public enum ModelEnum {
+    VIDEO_DSSM("videoDssm", "", DSSMModel.class),
+    NULL("null", "null", null);
+
+    private String modelName;
+    private String modelOssPath;
+    private Class<? extends Model> modelClass;
+
+    ModelEnum(String modelName, String modelOssPath, Class<? extends Model> modelClass) {
+        this.modelName = modelName;
+        this.modelOssPath = modelOssPath;
+        this.modelClass = modelClass;
+    }
+
+    public static ModelEnum which(String modelName) {
+        for (ModelEnum modelEnum : ModelEnum.values()) {
+            if (StringUtils.equals(modelEnum.getModelName(), modelName)) {
+                return modelEnum;
+            }
+        }
+        return NULL;
+    }
+
+    public String getModelName() {
+        return modelName;
+    }
+
+    public String getModelOssPath() {
+        return modelOssPath;
+    }
+
+    public Class<? extends Model> getModelClass() {
+        return modelClass;
+    }
+}

+ 11 - 16
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/ModelManager.java

@@ -9,10 +9,8 @@ import com.aliyun.oss.model.OSSObject;
 import com.ctrip.framework.apollo.Config;
 import com.ctrip.framework.apollo.ConfigService;
 import lombok.extern.slf4j.Slf4j;
-import org.springframework.ui.Model;
 
 import java.io.IOException;
-import java.io.InputStreamReader;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.Executors;
@@ -79,12 +77,13 @@ public class ModelManager {
 
     /**
      * 添加一个加载任务到管理器
-     *
-     * @param modelName  Model的名字, 注册到ModelManager的不同model需要不同的名字
-     * @param path       Model在OSS上的全路径
-     * @param modelClass Model的子类型
      */
-    public void registerModel(String modelName, String path, Class<? extends Model> modelClass) throws ModelRegisterException, IOException {
+    public void registerModel(ModelEnum modelEnum) {
+
+        String modelName = modelEnum.getModelName();
+        String path = modelEnum.getModelOssPath();
+        Class<? extends Model> modelClass = modelEnum.getModelClass();
+
         if (modelPathMap.containsKey(modelName)) {
             // fail fast
             // throw new RuntimeException(modelName + " already exists");
@@ -107,10 +106,9 @@ public class ModelManager {
 
     /**
      * 删除一个加载任务
-     *
-     * @param modelName Model的名字, 需要和registerModel的名字一致
      */
-    private void unRegisterModel(String modelName) {
+    private void unRegisterModel(ModelEnum modelEnum) {
+        String modelName = modelEnum.getModelName();
         if (modelPathMap.containsKey(modelName)) {
             String path = modelPathMap.get(modelName);
             if (loadTasks.containsKey(path)) {
@@ -124,11 +122,8 @@ public class ModelManager {
         }
     }
 
-    /**
-     * @param modelName
-     * @return
-     */
-    public Model getModel(String modelName) {
+    public Model getModel(ModelEnum modelEnum) {
+        String modelName = modelEnum.getModelName();
         if (modelPathMap.containsKey(modelName) && loadTasks.containsKey(modelPathMap.get(modelName))) {
             return loadTasks.get(modelPathMap.get(modelName)).model;
         } else {
@@ -190,7 +185,7 @@ public class ModelManager {
                         loadTask.lastModifyTime, timeStamp);
 
                 Model model = loadTask.modelClass.newInstance();
-                if (model.loadFromStream(new InputStreamReader(ossObj.getObjectContent()))) {
+                if (model.loadFromStream(ossObj.getObjectContent())) {
                     loadTask.model = model;
                     loadTask.lastModifyTime = timeStamp;
                 }

+ 0 - 15
recommend-model-service/src/main/java/com/tzld/piaoquan/recommend/model/service/model/OssConfig.java

@@ -1,15 +0,0 @@
-package com.tzld.piaoquan.recommend.model.service.model;
-
-import lombok.Data;
-
-/**
- * @author dyp
- */
-@Data
-public class OssConfig {
-
-    private String accessKeyId;
-    private String accessKeySecret;
-    private String endpoint;
-    private String bucketName;
-}