丁云鹏 4 meses atrás
pai
commit
14a452673e

+ 37 - 26
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMService.java

@@ -24,10 +24,7 @@ import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
 import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 
 /**
  * @author dyp
@@ -43,26 +40,6 @@ public class I2IDSSMService implements Serializable {
         int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
 
         // 加载模型
-        String bucketName = "art-recommend";
-        String objectName = "dyp/dssm.tar.gz";
-        OSSService ossService = new OSSService();
-
-        String gzPath = "/root/recommend-model/model.tar.gz";
-        ossService.download(bucketName, gzPath, objectName);
-        String modelDir = "/root/recommend-model";
-        CompressUtil.decompressGzFile(gzPath, modelDir);
-
-        String modelFile = modelDir + "/dssm.pdmodel";
-        String paramFile = modelDir + "/dssm.pdiparams";
-
-        Config config = new Config();
-        config.setCppModel(modelFile, paramFile);
-        config.enableMemoryOptim(true);
-        config.enableMKLDNN();
-        config.switchIrDebug(false);
-
-        Predictor predictor = Predictor.createPaddlePredictor(config);
-
         SparkSession spark = SparkSession.builder()
                 .appName("I2IDSSMInfer")
                 .getOrCreate();
@@ -71,8 +48,42 @@ public class I2IDSSMService implements Serializable {
         JavaRDD<String> rdd = jsc.textFile(file);
 
         // 定义处理数据的函数
-        JavaRDD<String> processedRdd = rdd.map(line -> processLine(line, predictor));
-
+        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
+            String bucketName = "art-recommend";
+            String objectName = "dyp/dssm.tar.gz";
+            OSSService ossService = new OSSService();
+
+            String gzPath = "/root/recommend-model/model.tar.gz";
+            ossService.download(bucketName, gzPath, objectName);
+            String modelDir = "/root/recommend-model";
+            CompressUtil.decompressGzFile(gzPath, modelDir);
+
+            String modelFile = modelDir + "/dssm.pdmodel";
+            String paramFile = modelDir + "/dssm.pdiparams";
+
+            Config config = new Config();
+            config.setCppModel(modelFile, paramFile);
+            config.enableMemoryOptim(true);
+            config.enableMKLDNN();
+            config.switchIrDebug(false);
+
+            Predictor predictor = Predictor.createPaddlePredictor(config);
+
+
+            return new Iterator<String>() {
+                private final Iterator<String> iterator = lines;
+
+                @Override
+                public boolean hasNext() {
+                    return iterator.hasNext();
+                }
+
+                @Override
+                public String next() {
+                    return processLine(iterator.next(), predictor);
+                }
+            };
+        });
         // 将处理后的数据写入新的文件,使用Gzip压缩
         String outputPath = "hdfs:/dyp/vec2";
         processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);

+ 2 - 2
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/OSSService.java

@@ -17,8 +17,8 @@ import java.io.Serializable;
 public class OSSService implements Serializable {
     private String accessId = "LTAI5tHMkNaRhpiDB1yWMZPn";
     private String accessKey = "XLi5YUJusVwbbQOaGeGsaRJ1Qyzbui";
-    //private String endpoint = "https://oss-cn-hangzhou-internal.aliyuncs.com";
-    private String endpoint = "https://oss-cn-hangzhou.aliyuncs.com";
+    private String endpoint = "https://oss-cn-hangzhou-internal.aliyuncs.com";
+    //private String endpoint = "https://oss-cn-hangzhou.aliyuncs.com";
 
     public void upload(String bucketName, String localFile, String objectName) {
         OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);