瀏覽代碼

dssm train

丁云鹏 4 月之前
父節點
當前提交
3d9b6db9fb

+ 4 - 2
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMService.java

@@ -13,6 +13,7 @@ import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.io.compress.GzipCodec;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 import org.apache.spark.ml.feature.VectorAssembler;
 import org.apache.spark.sql.Dataset;
@@ -64,14 +65,15 @@ public class I2IDSSMService {
 
         SparkSession spark = SparkSession.builder()
                 .appName("I2IDSSMInfer")
-                .master("local")
                 .getOrCreate();
 
         JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
         JavaRDD<String> rdd = jsc.textFile(file);
 
+        Broadcast<Predictor> broadcastPredictor = jsc.broadcast(predictor);;
+
         // 定义处理数据的函数
-        JavaRDD<String> processedRdd = rdd.map(line -> processLine(line, predictor));
+        JavaRDD<String> processedRdd = rdd.map(line -> processLine(line, broadcastPredictor.value()));
 
         // 将处理后的数据写入新的文件,使用Gzip压缩
         String outputPath = "hdfs:/dyp/vec2";