
dssm train

丁云鹏 5 months ago
parent
commit
f914eaca8a

+ 14 - 19
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMPredict.java

@@ -30,16 +30,19 @@ public class I2IDSSMPredict {
     public static void main(String[] args) throws IOException {
         CMDService cmd = new CMDService();
         Map<String, String> argMap = cmd.parse(args);
-        String file = argMap.get("path");
+        String file = argMap.get("predictDataPath");
+        String modelOssObjectName = argMap.get("modelOssObjectName");
+        String modelOssBucketName = argMap.get("modelOssBucketName");
+        String modelHdfsSavePath = argMap.get("modelHdfsSavePath");
+        String vecOutputPath = argMap.get("vecOutputPath");
+
         int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
-        String bucketName2 = "art-recommend";
-        String objectName2 = "dyp/dssm.tar.gz";
-        OSSService ossService2 = new OSSService();
 
-        String gzPath2 = "/root/recommend-model/model.tar.gz";
-        ossService2.download(bucketName2, gzPath2, objectName2);
+        OSSService ossService = new OSSService();
+        String gzPath = "/root/recommend-model/model.tar.gz";
+        ossService.download(modelOssBucketName, gzPath, modelOssObjectName);
 
-        hdfsService.copyFromLocalFile(gzPath2, "/dyp/dssm/model.tar.gz");
+        hdfsService.copyFromLocalFile(gzPath, modelHdfsSavePath);
         // Load the model
         SparkSession spark = SparkSession.builder()
                 .appName("I2IDSSMInfer")
@@ -51,14 +54,7 @@ public class I2IDSSMPredict {
         // 定义处理数据的函数
         JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
             System.loadLibrary("paddle_inference");
-//            String bucketName = "art-recommend";
-//            String objectName = "dyp/dssm.tar.gz";
-//            OSSService ossService = new OSSService();
-//
-//            String gzPath = "/root/recommend-model/model.tar.gz";
-//            ossService.download(bucketName, gzPath, objectName);
-            // String modelDir = "/root/recommend-model";
-            hdfsService.copyToLocalFile("/dyp/dssm/model.tar.gz", "./model.tar.gz");
+            hdfsService.copyToLocalFile(modelHdfsSavePath, "./model.tar.gz");
             CompressUtil.decompressGzFile("./model.tar.gz", ".");
 
             String modelFile = "dssm.pdmodel";
@@ -126,13 +122,12 @@ public class I2IDSSMPredict {
             };
         });
         // Write the processed data to a new file, compressed with Gzip
-        String outputPath = "hdfs:/dyp/vec2";
         try {
-            hdfsService.deleteIfExist(outputPath);
+            hdfsService.deleteIfExist(vecOutputPath);
         } catch (Exception e) {
-            log.error("deleteOnExit error outputPath {}", outputPath, e);
+            log.error("deleteIfExist error outputPath {}", vecOutputPath, e);
         }
-        processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);
+        processedRdd.coalesce(repartition).saveAsTextFile(vecOutputPath, GzipCodec.class);
     }
 
 }

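The change above replaces the hardcoded OSS bucket, object, and HDFS paths with named arguments (predictDataPath, modelOssObjectName, modelOssBucketName, modelHdfsSavePath, vecOutputPath) read through CMDService. The argument syntax CMDService.parse accepts is not shown in this diff; assuming a simple key=value convention, a minimal stand-in parser and invocation could look like the sketch below (ArgParseSketch and the example paths are hypothetical).

import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of a key=value argument parser; the real CMDService
// implementation is not part of this diff and may differ.
public class ArgParseSketch {
    static Map<String, String> parse(String[] args) {
        Map<String, String> map = new HashMap<>();
        for (String arg : args) {
            int idx = arg.indexOf('=');
            if (idx > 0) {
                map.put(arg.substring(0, idx), arg.substring(idx + 1));
            }
        }
        return map;
    }

    public static void main(String[] args) {
        // Example invocation mirroring the new named parameters:
        //   predictDataPath=hdfs:/path/to/input modelOssBucketName=my-bucket
        //   modelOssObjectName=path/to/dssm.tar.gz modelHdfsSavePath=/path/to/model.tar.gz
        //   vecOutputPath=hdfs:/path/to/vec repartition=64
        Map<String, String> argMap = parse(args);
        System.out.println(argMap.get("predictDataPath"));
    }
}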
+ 0 - 133
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMService.java

@@ -1,133 +0,0 @@
-package com.tzld.piaoquan.recommend.model.produce.i2i;
-
-import com.baidu.paddle.inference.Config;
-import com.baidu.paddle.inference.Predictor;
-import com.baidu.paddle.inference.Tensor;
-import com.tzld.piaoquan.recommend.model.produce.service.CMDService;
-import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
-import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
-import lombok.extern.slf4j.Slf4j;
-import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
-import org.apache.commons.lang.math.NumberUtils;
-import org.apache.commons.lang3.StringUtils;
-import org.apache.hadoop.io.compress.GzipCodec;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
-import org.apache.spark.ml.feature.VectorAssembler;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-
-import java.io.Serializable;
-import java.util.*;
-
-/**
- * @author dyp
- */
-@Slf4j
-public class I2IDSSMService implements Serializable {
-
-    public void predict(String[] args) {
-
-        CMDService cmd = new CMDService();
-        Map<String, String> argMap = cmd.parse(args);
-        String file = argMap.get("path");
-        int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
-
-        // 加载模型
-        SparkSession spark = SparkSession.builder()
-                .appName("I2IDSSMInfer")
-                .getOrCreate();
-
-        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
-        JavaRDD<String> rdd = jsc.textFile(file);
-
-        // 定义处理数据的函数
-        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
-            String bucketName = "art-recommend";
-            String objectName = "dyp/dssm.tar.gz";
-            OSSService ossService = new OSSService();
-
-            String gzPath = "/root/recommend-model/model.tar.gz";
-            ossService.download(bucketName, gzPath, objectName);
-            String modelDir = "/root/recommend-model";
-            CompressUtil.decompressGzFile(gzPath, modelDir);
-
-            String modelFile = modelDir + "/dssm.pdmodel";
-            String paramFile = modelDir + "/dssm.pdiparams";
-
-            Config config = new Config();
-            config.setCppModel(modelFile, paramFile);
-            config.enableMemoryOptim(true);
-            config.enableMKLDNN();
-            config.switchIrDebug(false);
-
-            Predictor predictor = Predictor.createPaddlePredictor(config);
-
-
-            return new Iterator<String>() {
-                private final Iterator<String> iterator = lines;
-
-                @Override
-                public boolean hasNext() {
-                    return iterator.hasNext();
-                }
-
-                @Override
-                public String next() {
-                    return processLine(iterator.next(), predictor);
-                }
-            };
-        });
-        // Write the processed data to a new file, compressed with Gzip
-        String outputPath = "hdfs:/dyp/vec2";
-        processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);
-    }
-
-    private String processLine(String line, Predictor predictor) {
-
-        // 1 Parse the input line
-        String[] sampleValues = line.split("\t", -1); // -1 keeps trailing empty strings
-
-        // Check that there are at least two fields (vid and left_features_str)
-        if (sampleValues.length >= 2) {
-            String vid = sampleValues[0];
-            String leftFeaturesStr = sampleValues[1];
-
-            // Split left_features_str and convert it to a float array
-            String[] leftFeaturesArray = leftFeaturesStr.split(",");
-            float[] leftFeatures = new float[leftFeaturesArray.length];
-            for (int i = 0; i < leftFeaturesArray.length; i++) {
-                leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
-            }
-            String inNames = predictor.getInputNameById(0);
-            Tensor inHandle = predictor.getInputHandle(inNames);
-            // 2 Set the input
-            inHandle.reshape(2, new int[]{1, 157});
-            inHandle.copyFromCpu(leftFeatures);
-
-            // 3 Run inference
-            predictor.run();
-
-            // 4 Fetch the output tensor
-            String outNames = predictor.getOutputNameById(0);
-            Tensor outHandle = predictor.getOutputHandle(outNames);
-            float[] outData = new float[outHandle.getSize()];
-            outHandle.copyToCpu(outData);
-
-            String result = vid + "\t" + outData[0];
-
-            outHandle.destroyNativeTensor();
-            inHandle.destroyNativeTensor();
-            predictor.destroyNativePredictor();
-
-            return result;
-        }
-        return "";
-    }
-}
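Note that the deleted I2IDSSMService called predictor.destroyNativePredictor() inside processLine, so only the first record of each partition could be scored before the native predictor was gone. The surviving I2IDSSMPredict uses the same Config/Predictor/Tensor call sequence; a minimal sketch of that per-partition flow, with the predictor built once and reused across lines (model file names, the {1, 157} input shape, and the API calls are taken from the code above; feature parsing and error handling omitted):

import com.baidu.paddle.inference.Config;
import com.baidu.paddle.inference.Predictor;
import com.baidu.paddle.inference.Tensor;

// Sketch of the per-partition inference flow: build the predictor once,
// then score each line with fresh input/output tensors.
public class DSSMInferSketch {
    public static void main(String[] args) {
        Config config = new Config();
        config.setCppModel("dssm.pdmodel", "dssm.pdiparams");
        config.enableMemoryOptim(true);
        config.enableMKLDNN();
        Predictor predictor = Predictor.createPaddlePredictor(config);

        float[] leftFeatures = new float[157]; // one parsed sample row

        Tensor in = predictor.getInputHandle(predictor.getInputNameById(0));
        in.reshape(2, new int[]{1, 157});
        in.copyFromCpu(leftFeatures);

        predictor.run();

        Tensor out = predictor.getOutputHandle(predictor.getOutputNameById(0));
        float[] outData = new float[out.getSize()];
        out.copyToCpu(outData);

        // Release the per-call tensors, but keep the predictor alive so the
        // next line in the partition can be scored (unlike the deleted
        // processLine, which also tore down the native predictor here).
        out.destroyNativeTensor();
        in.destroyNativeTensor();

        predictor.destroyNativePredictor(); // only once, after the last line
    }
}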