Browse Source

dssm train

丁云鹏 4 months ago
parent
commit
c0df250bd9

+ 3 - 1
recommend-model-jni/src/main/java/com/baidu/paddle/inference/Predictor.java

@@ -1,6 +1,8 @@
 package com.baidu.paddle.inference;
 
-public class Predictor {
+import java.io.Serializable;
+
+public class Predictor implements Serializable {
 
     private long cppPaddlePredictorPointer;
 

+ 27 - 38
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMService.java

@@ -13,7 +13,6 @@ import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.io.compress.GzipCodec;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 import org.apache.spark.ml.feature.VectorAssembler;
 import org.apache.spark.sql.Dataset;
@@ -24,7 +23,10 @@ import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
-import java.util.*;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 
 /**
  * @author dyp
@@ -39,6 +41,27 @@ public class I2IDSSMService {
         String file = argMap.get("path");
         int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
 
+        // 加载模型
+        String bucketName = "art-recommend";
+        String objectName = "dyp/dssm.tar.gz";
+        OSSService ossService = new OSSService();
+
+        String gzPath = "/root/recommend-model/model.tar.gz";
+        ossService.download(bucketName, gzPath, objectName);
+        String modelDir = "/root/recommend-model";
+        CompressUtil.decompressGzFile(gzPath, modelDir);
+
+        String modelFile = modelDir + "/dssm.pdmodel";
+        String paramFile = modelDir + "/dssm.pdiparams";
+
+        Config config = new Config();
+        config.setCppModel(modelFile, paramFile);
+        config.enableMemoryOptim(true);
+        config.enableMKLDNN();
+        config.switchIrDebug(false);
+
+        Predictor predictor = Predictor.createPaddlePredictor(config);
+
         SparkSession spark = SparkSession.builder()
                 .appName("I2IDSSMInfer")
                 .getOrCreate();
@@ -46,43 +69,9 @@ public class I2IDSSMService {
         JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
         JavaRDD<String> rdd = jsc.textFile(file);
 
-
         // 定义处理数据的函数
-        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
-            // 在每个分区中初始化Predictor
-            String bucketName = "art-recommend";
-            String objectName = "dyp/dssm.tar.gz";
-            OSSService ossService = new OSSService();
-
-            String gzPath = "/root/recommend-model/model.tar.gz";
-            ossService.download(bucketName, gzPath, objectName);
-            String modelDir = "/root/recommend-model";
-            CompressUtil.decompressGzFile(gzPath, modelDir);
-
-            String modelFile = modelDir + "/dssm.pdmodel";
-            String paramFile = modelDir + "/dssm.pdiparams";
-
-            Config config = new Config();
-            config.setCppModel(modelFile, paramFile);
-            config.enableMemoryOptim(true);
-            config.enableMKLDNN();
-            config.switchIrDebug(false);
-
-            Predictor predictor = Predictor.createPaddlePredictor(config);
-
-            // 使用Predictor处理每个分区中的数据
-            return new Iterator<String>() {
-                private final Iterator<String> iterator = lines;
-                @Override
-                public boolean hasNext() {
-                    return iterator.hasNext();
-                }
-                @Override
-                public String next() {
-                    return processLine(iterator.next(), predictor);
-                }
-            };
-        });
+        JavaRDD<String> processedRdd = rdd.map(line -> processLine(line, predictor));
+
         // 将处理后的数据写入新的文件,使用Gzip压缩
         String outputPath = "hdfs:/dyp/vec2";
         processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);