@@ -23,6 +23,7 @@ import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
 
+import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -32,7 +33,7 @@ import java.util.Map;
* @author dyp
*/
@Slf4j
-public class I2IDSSMService {
+public class I2IDSSMService implements Serializable {
 
public void predict(String[] args) {
 
@@ -42,25 +43,25 @@ public class I2IDSSMService {
int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
 
         // Load the model
-// String bucketName = "art-recommend";
-// String objectName = "dyp/dssm.tar.gz";
-// OSSService ossService = new OSSService();
-//
-// String gzPath = "/root/recommend-model/model.tar.gz";
-// ossService.download(bucketName, gzPath, objectName);
-// String modelDir = "/root/recommend-model";
-// CompressUtil.decompressGzFile(gzPath, modelDir);
-//
-// String modelFile = modelDir + "/dssm.pdmodel";
-// String paramFile = modelDir + "/dssm.pdiparams";
-//
-// Config config = new Config();
-// config.setCppModel(modelFile, paramFile);
-// config.enableMemoryOptim(true);
-// config.enableMKLDNN();
-// config.switchIrDebug(false);
-//
-// Predictor predictor = Predictor.createPaddlePredictor(config);
+ String bucketName = "art-recommend";
+ String objectName = "dyp/dssm.tar.gz";
+ OSSService ossService = new OSSService();
+
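+        // Download the packaged DSSM model from OSS and extract it into the local model directory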
+ String gzPath = "/root/recommend-model/model.tar.gz";
+ ossService.download(bucketName, gzPath, objectName);
+ String modelDir = "/root/recommend-model";
+ CompressUtil.decompressGzFile(gzPath, modelDir);
+
+ String modelFile = modelDir + "/dssm.pdmodel";
+ String paramFile = modelDir + "/dssm.pdiparams";
+
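+        // Configure Paddle Inference with the extracted model and parameter files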
+ Config config = new Config();
+ config.setCppModel(modelFile, paramFile);
+ config.enableMemoryOptim(true);
+ config.enableMKLDNN();
+ config.switchIrDebug(false);
+
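+        // Create the predictor on the driver; the map lambda below captures it together with this instance (hence the Serializable change)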
+ Predictor predictor = Predictor.createPaddlePredictor(config);
 
SparkSession spark = SparkSession.builder()
.appName("I2IDSSMInfer")
@@ -70,17 +71,13 @@ public class I2IDSSMService {
JavaRDD<String> rdd = jsc.textFile(file);
 
         // Define the data-processing function
- JavaRDD<String> processedRdd = rdd.map(line -> processLine(line));
+ JavaRDD<String> processedRdd = rdd.map(line -> processLine(line, predictor));
 
         // Write the processed data to a new file, using Gzip compression
String outputPath = "hdfs:/dyp/vec2";
processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);
}
 
- private String processLine(String line) {
- return "";
- }
-
private String processLine(String line, Predictor predictor) {
 
         // 1. Process the data