瀏覽代碼

dssm train

丁云鹏 4 月之前
父節點
當前提交
460d027821

+ 36 - 28
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMService.java

@@ -24,10 +24,7 @@ import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 
 /**
  * @author dyp
@@ -42,27 +39,6 @@ public class I2IDSSMService {
         String file = argMap.get("path");
         int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
 
-        // 加载模型
-        String bucketName = "art-recommend";
-        String objectName = "dyp/dssm.tar.gz";
-        OSSService ossService = new OSSService();
-
-        String gzPath = "/root/recommend-model/model.tar.gz";
-        ossService.download(bucketName, gzPath, objectName);
-        String modelDir = "/root/recommend-model";
-        CompressUtil.decompressGzFile(gzPath, modelDir);
-
-        String modelFile = modelDir + "/dssm.pdmodel";
-        String paramFile = modelDir + "/dssm.pdiparams";
-
-        Config config = new Config();
-        config.setCppModel(modelFile, paramFile);
-        config.enableMemoryOptim(true);
-        config.enableMKLDNN();
-        config.switchIrDebug(false);
-
-        Predictor predictor = Predictor.createPaddlePredictor(config);
-
         SparkSession spark = SparkSession.builder()
                 .appName("I2IDSSMInfer")
                 .getOrCreate();
@@ -70,11 +46,43 @@ public class I2IDSSMService {
         JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
         JavaRDD<String> rdd = jsc.textFile(file);
 
-        Broadcast<Predictor> broadcastPredictor = jsc.broadcast(predictor);;
 
         // 定义处理数据的函数
-        JavaRDD<String> processedRdd = rdd.map(line -> processLine(line, broadcastPredictor.value()));
-
+        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
+            // 在每个分区中初始化Predictor
+            String bucketName = "art-recommend";
+            String objectName = "dyp/dssm.tar.gz";
+            OSSService ossService = new OSSService();
+
+            String gzPath = "/root/recommend-model/model.tar.gz";
+            ossService.download(bucketName, gzPath, objectName);
+            String modelDir = "/root/recommend-model";
+            CompressUtil.decompressGzFile(gzPath, modelDir);
+
+            String modelFile = modelDir + "/dssm.pdmodel";
+            String paramFile = modelDir + "/dssm.pdiparams";
+
+            Config config = new Config();
+            config.setCppModel(modelFile, paramFile);
+            config.enableMemoryOptim(true);
+            config.enableMKLDNN();
+            config.switchIrDebug(false);
+
+            Predictor predictor = Predictor.createPaddlePredictor(config);
+
+            // 使用Predictor处理每个分区中的数据
+            return new Iterator<String>() {
+                private final Iterator<String> iterator = lines;
+                @Override
+                public boolean hasNext() {
+                    return iterator.hasNext();
+                }
+                @Override
+                public String next() {
+                    return processLine(iterator.next(), predictor);
+                }
+            };
+        });
         // 将处理后的数据写入新的文件,使用Gzip压缩
         String outputPath = "hdfs:/dyp/vec2";
         processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);