@@ -13,7 +13,6 @@ import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.io.compress.GzipCodec;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.broadcast.Broadcast;
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 import org.apache.spark.ml.feature.VectorAssembler;
 import org.apache.spark.sql.Dataset;
@@ -24,7 +23,10 @@ import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
-import java.util.*;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 
 /**
  * @author dyp
@@ -46,43 +48,40 @@
 JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
 JavaRDD<String> rdd = jsc.textFile(file);
 
-
 // 定义处理数据的函数
-        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
-            // 在每个分区中初始化Predictor
-            String bucketName = "art-recommend";
-            String objectName = "dyp/dssm.tar.gz";
-            OSSService ossService = new OSSService();
-
-            String gzPath = "/root/recommend-model/model.tar.gz";
-            ossService.download(bucketName, gzPath, objectName);
-            String modelDir = "/root/recommend-model";
-            CompressUtil.decompressGzFile(gzPath, modelDir);
-
-            String modelFile = modelDir + "/dssm.pdmodel";
-            String paramFile = modelDir + "/dssm.pdiparams";
-
-            Config config = new Config();
-            config.setCppModel(modelFile, paramFile);
-            config.enableMemoryOptim(true);
-            config.enableMKLDNN();
-            config.switchIrDebug(false);
-
-            Predictor predictor = Predictor.createPaddlePredictor(config);
-
-            // 使用Predictor处理每个分区中的数据
-            return new Iterator<String>() {
-                private final Iterator<String> iterator = lines;
-                @Override
-                public boolean hasNext() {
-                    return iterator.hasNext();
-                }
-                @Override
-                public String next() {
-                    return processLine(iterator.next(), predictor);
-                }
-            };
-        });
+        // NOTE(review): Predictor wraps native Paddle handles and is not
+        // java.io.Serializable, so it must NOT be created on the driver and
+        // captured by a map() lambda (Task not serializable). Keep the
+        // per-partition, executor-side initialization.
+        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
+            // Executor-side init: fetch and unpack the DSSM model from OSS.
+            String bucketName = "art-recommend";
+            String objectName = "dyp/dssm.tar.gz";
+            OSSService ossService = new OSSService();
+
+            String gzPath = "/root/recommend-model/model.tar.gz";
+            ossService.download(bucketName, gzPath, objectName);
+            String modelDir = "/root/recommend-model";
+            CompressUtil.decompressGzFile(gzPath, modelDir);
+
+            String modelFile = modelDir + "/dssm.pdmodel";
+            String paramFile = modelDir + "/dssm.pdiparams";
+
+            Config config = new Config();
+            config.setCppModel(modelFile, paramFile);
+            config.enableMemoryOptim(true);
+            config.enableMKLDNN();
+            config.switchIrDebug(false);
+            Predictor predictor = Predictor.createPaddlePredictor(config);
+
+            // Buffer the transformed partition instead of the hand-rolled
+            // anonymous Iterator — simpler and behaviorally equivalent.
+            List<String> results = new ArrayList<>();
+            while (lines.hasNext()) {
+                results.add(processLine(lines.next(), predictor));
+            }
+            return results.iterator();
+        });
+
 // 将处理后的数据写入新的文件,使用Gzip压缩
 String outputPath = "hdfs:/dyp/vec2";
 processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);