
dssm train

丁云鹏 4 months ago
parent commit 0a109a25cc

+ 5 - 0
recommend-model-produce/pom.xml

@@ -204,6 +204,11 @@
                 </exclusion>
             </exclusions>
         </dependency>
+        <dependency>
+            <groupId>com.tzld.piaoquan</groupId>
+            <artifactId>recommend-model-jni</artifactId>
+            <version>1.0.0</version>
+        </dependency>
     </dependencies>
     <build>
         <plugins>
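
Note: the recommend-model-jni artifact added above appears to be what supplies the com.baidu.paddle.inference bindings (Config, Predictor, Tensor) imported in the service change below. A minimal smoke test of the binding, assuming the Paddle native runtime is discoverable via java.library.path and using placeholder model paths:

    import com.baidu.paddle.inference.Config;
    import com.baidu.paddle.inference.Predictor;

    public class PaddleBindingSmokeTest {
        public static void main(String[] args) {
            Config config = new Config();
            // Placeholder paths; point these at a real exported DSSM model.
            config.setCppModel("/tmp/dssm.pdmodel", "/tmp/dssm.pdiparams");
            Predictor predictor = Predictor.createPaddlePredictor(config);
            // Printing an input name confirms the native library actually loaded.
            System.out.println("first input: " + predictor.getInputNameById(0));
            predictor.destroyNativePredictor();
        }
    }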

+ 66 - 80
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMService.java

@@ -1,5 +1,8 @@
 package com.tzld.piaoquan.recommend.model.produce.i2i;
 
+import com.baidu.paddle.inference.Config;
+import com.baidu.paddle.inference.Predictor;
+import com.baidu.paddle.inference.Tensor;
 import com.tzld.piaoquan.recommend.model.produce.service.CMDService;
 import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
 import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
@@ -7,6 +10,7 @@ import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import org.apache.commons.lang.math.NumberUtils;
 import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.io.compress.GzipCodec;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
@@ -32,103 +36,85 @@ public class I2IDSSMService {
 
     public void predict(String[] args) {
 
-        try {
+        CMDService cmd = new CMDService();
+        Map<String, String> argMap = cmd.parse(args);
+        String file = argMap.get("path");
 
-            CMDService cmd = new CMDService();
-            Map<String, String> argMap = cmd.parse(args);
-            String file = argMap.get("path");
+        // Load the model from OSS
+        String bucketName = "art-recommend";
+        String objectName = "dyp/dssm.tar.gz";
+        OSSService ossService = new OSSService();
 
-            if(StringUtils.isBlank(file)){
-                String dir = argMap.get("dir");
-            }
-            
-            // load the model
-            String bucketName = "art-test-video";
-            String objectName = "test/model.tar.gz";
-            OSSService ossService = new OSSService();
-
-            String gzPath = "/root/recommend-model/model2.tar.gz";
-            ossService.download(bucketName, gzPath, objectName);
-            String modelDir = "/root/recommend-model/modelpredict";
-            CompressUtil.decompressGzFile(gzPath, modelDir);
+        String gzPath = "/root/recommend-model/model.tar.gz";
+        ossService.download(bucketName, gzPath, objectName);
+        String modelDir = "/root/recommend-model";
+        CompressUtil.decompressGzFile(gzPath, modelDir);
 
+        String modelFile = modelDir + "/dssm.pdmodel";
+        String paramFile = modelDir + "/dssm.pdiparams";
 
-        } catch (Throwable e) {
-            log.error("", e);
-        }
-    }
-
-    private static Dataset<Row> dataset(String path) {
-        String[] features = {
-                "cpa",
-                "b2_1h_ctr",
-                "b2_1h_ctcvr",
-                "b2_1h_cvr",
-                "b2_1h_conver",
-                "b2_1h_click",
-                "b2_1h_conver*log(view)",
-                "b2_1h_conver*ctcvr",
-                "b2_2h_ctr",
-                "b2_2h_ctcvr",
-                "b2_2h_cvr",
-                "b2_2h_conver",
-                "b2_2h_click",
-                "b2_2h_conver*log(view)",
-                "b2_2h_conver*ctcvr",
-                "b2_3h_ctr",
-                "b2_3h_ctcvr",
-                "b2_3h_cvr",
-                "b2_3h_conver",
-                "b2_3h_click",
-                "b2_3h_conver*log(view)",
-                "b2_3h_conver*ctcvr",
-                "b2_6h_ctr",
-                "b2_6h_ctcvr"
-        };
+        Config config = new Config();
+        config.setCppModel(modelFile, paramFile);
+        config.enableMemoryOptim(true);
+        config.enableMKLDNN();
+        config.switchIrDebug(false);
 
+        Predictor predictor = Predictor.createPaddlePredictor(config);
 
         SparkSession spark = SparkSession.builder()
-                .appName("XGBoostTrain")
+                .appName("I2IDSSMInfer")
                 .getOrCreate();
 
         JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
-        String file = path;
         JavaRDD<String> rdd = jsc.textFile(file);
 
-        JavaRDD<Row> rowRDD = rdd.map(s -> {
-            String[] line = StringUtils.split(s, '\t');
-            double label = NumberUtils.toDouble(line[0]);
-            // select features
-            Map<String, Double> map = new HashMap<>();
-            for (int i = 1; i < line.length; i++) {
-                String[] fv = StringUtils.split(line[i], ':');
-                map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
-            }
+        // Define the per-record processing function
+        JavaRDD<String> processedRdd = rdd.map(line -> processLine(line, predictor));
+
+        // Write the processed records to a new file, Gzip-compressed
+        String outputPath = "hdfs:/dyp/vec2";
+        processedRdd.saveAsTextFile(outputPath, GzipCodec.class);
+    }
+
+    private String processLine(String line, Predictor predictor) {
+
+        // 1. Parse the input line
+        String[] sampleValues = line.split("\t", -1); // -1 keeps trailing empty strings
+
+        // Require at least two fields: vid and left_features_str
+        if (sampleValues.length >= 2) {
+            String vid = sampleValues[0];
+            String leftFeaturesStr = sampleValues[1];
 
-            Object[] v = new Object[features.length + 1];
-            v[0] = label;
-            for (int i = 0; i < features.length; i++) {
-                v[i + 1] = map.getOrDefault(features[i], 0.0d);
+            // Split left_features_str and convert it to a float array
+            String[] leftFeaturesArray = leftFeaturesStr.split(",");
+            float[] leftFeatures = new float[leftFeaturesArray.length];
+            for (int i = 0; i < leftFeaturesArray.length; i++) {
+                leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
             }
+            String inNames = predictor.getInputNameById(0);
+            Tensor inHandle = predictor.getInputHandle(inNames);
+            // 2. Set up the input tensor (the model expects shape [1, 157])
+            inHandle.reshape(2, new int[]{1, 157});
+            inHandle.copyFromCpu(leftFeatures);
 
-            return RowFactory.create(v);
-        });
+            // 3. Run inference
+            predictor.run();
 
-        log.info("rowRDD count {}", rowRDD.count());
-        // 将 JavaRDD<Row> 转换为 Dataset<Row>
-        List<StructField> fields = new ArrayList<>();
-        fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
-        for (String f : features) {
-            fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
-        }
-        StructType schema = DataTypes.createStructType(fields);
-        Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
+            // 4. Fetch the output tensor
+            String outNames = predictor.getOutputNameById(0);
+            Tensor outHandle = predictor.getOutputHandle(outNames);
+            float[] outData = new float[outHandle.getSize()];
+            outHandle.copyToCpu(outData);
+
+            String result = vid + "\t" + outData[0];
 
-        VectorAssembler assembler = new VectorAssembler()
-                .setInputCols(features)
-                .setOutputCol("features");
+            outHandle.destroyNativeTensor();
+            inHandle.destroyNativeTensor();
+            // The predictor is shared across records; do not destroy it here.
 
-        Dataset<Row> assembledData = assembler.transform(dataset);
-        return assembledData;
+            return result;
+        }
+        return "";
     }
 }
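
A lifecycle note on the new inference path: processLine runs once per record inside rdd.map, so the Predictor built on the driver is captured by the Spark closure. Native Paddle handles are not Java-serializable, and a single shared predictor should in any case not be torn down mid-job. A minimal sketch of one common workaround, creating one predictor per partition via mapPartitions; it assumes the model files are present on every executor (e.g. shipped with --files or a shared mount) and that processLine is reachable from the closure (e.g. made static):

    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.List;

    // Sketch only, not the committed code: one Predictor per partition keeps
    // the native handle out of the serialized closure.
    JavaRDD<String> processedRdd = rdd.mapPartitions((Iterator<String> it) -> {
        Config config = new Config();
        config.setCppModel(modelFile, paramFile); // same files downloaded above
        config.enableMemoryOptim(true);
        config.enableMKLDNN();
        Predictor predictor = Predictor.createPaddlePredictor(config);

        List<String> out = new ArrayList<>();
        while (it.hasNext()) {
            out.add(processLine(it.next(), predictor));
        }
        predictor.destroyNativePredictor(); // safe here: the partition is done
        return out.iterator();
    });

The tensor handles are still created and released per record inside processLine; only the predictor's lifetime widens to the partition.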