丁云鹏 4 months ago
parent revision a8ec515755

+ 2 - 2
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMPredict.java

@@ -90,7 +90,7 @@ public class I2IDSSMPredict {
 
                        // Split left_features_str and convert it to a float array
                         String[] leftFeaturesArray = leftFeaturesStr.split(",");
-                        float[] leftFeatures = new float[leftFeaturesArray.length];
+                        double[] leftFeatures = new double[leftFeaturesArray.length];
                         for (int i = 0; i < leftFeaturesArray.length; i++) {
                             leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
                         }
@@ -106,7 +106,7 @@ public class I2IDSSMPredict {
                        // 4 Get the output Tensor
                         String outNames = predictor.getOutputNameById(0);
                         Tensor outHandle = predictor.getOutputHandle(outNames);
-                        float[] outData = new float[outHandle.getSize()];
+                        double[] outData = new double[outHandle.getSize()];
                         outHandle.copyToCpu(outData);
 
 

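Note on the change above: the buffers move from float[] to double[], while the parse loop still uses Float.parseFloat, whose float result is merely widened to double on assignment. A minimal sketch of parsing the comma-separated feature string straight into double precision (plain Java, independent of the Paddle tensor API; parseFeatures is a hypothetical helper, not part of this commit):

    // Hypothetical helper (not in this commit): parse a comma-separated
    // feature string directly into a double[] so no float round-trip occurs.
    static double[] parseFeatures(String leftFeaturesStr) {
        String[] parts = leftFeaturesStr.split(",");
        double[] features = new double[parts.length];
        for (int i = 0; i < parts.length; i++) {
            features[i] = Double.parseDouble(parts[i].trim());
        }
        return features;
    }
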
+ 133 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IMilvusDataImport.java

@@ -0,0 +1,133 @@
+package com.tzld.piaoquan.recommend.model.produce.i2i;
+
+import com.baidu.paddle.inference.Config;
+import com.baidu.paddle.inference.Predictor;
+import com.baidu.paddle.inference.Tensor;
+import com.tzld.piaoquan.recommend.model.produce.service.CMDService;
+import com.tzld.piaoquan.recommend.model.produce.service.HDFSService;
+import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
+import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.io.compress.GzipCodec;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * @author dyp
+ */
+@Slf4j
+public class I2IMilvusDataImport {
+
+    private static HDFSService hdfsService = new HDFSService();
+
+    public static void main(String[] args) throws IOException {
+        CMDService cmd = new CMDService();
+        Map<String, String> argMap = cmd.parse(args);
+        String file = argMap.get("vecOutputPath");
+        String modelOssObjectName = argMap.get("modelOssObjectName");
+        String modelOssBucketName = argMap.get("modelOssBucketName");
+        String modelHdfsSavePath = argMap.get("modelHdfsSavePath");
+        String vecOutputPath = argMap.get("vecOutputPath");
+
+        int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
+
+        OSSService ossService = new OSSService();
+        String gzPath = "/root/recommend-model/model.tar.gz";
+        ossService.download(modelOssBucketName, gzPath, modelOssObjectName);
+
+        hdfsService.copyFromLocalFile(gzPath, modelHdfsSavePath);
+        // Load the model
+        SparkSession spark = SparkSession.builder()
+                .appName("I2IDSSMInfer")
+                .getOrCreate();
+
+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+        JavaRDD<String> rdd = jsc.textFile(file);
+
+        // Define the data-processing function
+        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
+            System.loadLibrary("paddle_inference");
+            hdfsService.copyToLocalFile(modelHdfsSavePath, "./model.tar.gz");
+            CompressUtil.decompressGzFile("./model.tar.gz", ".");
+
+            String modelFile = "dssm.pdmodel";
+            String paramFile = "dssm.pdiparams";
+
+            Config config = new Config();
+            config.setCppModel(modelFile, paramFile);
+            config.enableMemoryOptim(true);
+            config.enableMKLDNN();
+            config.switchIrDebug(false);
+
+            Predictor predictor = Predictor.createPaddlePredictor(config);
+
+
+            return new Iterator<String>() {
+                private final Iterator<String> iterator = lines;
+
+                @Override
+                public boolean hasNext() {
+                    return iterator.hasNext();
+                }
+
+                @Override
+                public String next() {
+                    // 1 Process the data
+                    String line = lines.next();
+                    String[] sampleValues = line.split("\t", -1); // the -1 limit keeps trailing empty strings
+
+                    // Check that there are at least two fields (vid and left_features_str)
+                    if (sampleValues.length >= 2) {
+                        String vid = sampleValues[0];
+                        String leftFeaturesStr = sampleValues[1];
+
+                        // Split left_features_str and convert it to a float array
+                        String[] leftFeaturesArray = leftFeaturesStr.split(",");
+                        float[] leftFeatures = new float[leftFeaturesArray.length];
+                        for (int i = 0; i < leftFeaturesArray.length; i++) {
+                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
+                        }
+                        String inNames = predictor.getInputNameById(0);
+                        Tensor inHandle = predictor.getInputHandle(inNames);
+                        // 2 Set the input
+                        inHandle.reshape(2, new int[]{1, 157});
+                        inHandle.copyFromCpu(leftFeatures);
+
+                        // 3 Run prediction
+                        predictor.run();
+
+                        // 4 Get the output Tensor
+                        String outNames = predictor.getOutputNameById(0);
+                        Tensor outHandle = predictor.getOutputHandle(outNames);
+                        float[] outData = new float[outHandle.getSize()];
+                        outHandle.copyToCpu(outData);
+
+
+                        String result = vid + "\t[" + StringUtils.join(outData, ',') + "]";
+
+                        outHandle.destroyNativeTensor();
+                        inHandle.destroyNativeTensor();
+
+                        return result;
+                    }
+                    return "";
+                }
+            };
+        });
+        // Write the processed data to a new file with Gzip compression
+        try {
+            hdfsService.deleteIfExist(vecOutputPath);
+        } catch (Exception e) {
+            log.error("deleteIfExist error outputPath {}", vecOutputPath, e);
+        }
+        processedRdd.coalesce(repartition).saveAsTextFile(vecOutputPath, GzipCodec.class);
+    }
+
+}
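
The new job emits one line per video in the form vid, a tab, and the bracketed comma-joined vector; rows with fewer than two fields fall through to return "" and are therefore written out as blank lines, and despite the class name the Milvus insert itself is not part of this diff. A small sketch, assuming the same processedRdd, repartition and vecOutputPath as in main above, that drops the empty placeholders before saving, plus a hypothetical parser for the emitted line format (both illustrative only, not part of the commit):

    // Sketch: filter out the "" placeholders produced for malformed rows
    // before writing (reuses processedRdd, repartition, vecOutputPath from main).
    JavaRDD<String> nonEmpty = processedRdd.filter(s -> s != null && !s.isEmpty());
    nonEmpty.coalesce(repartition).saveAsTextFile(vecOutputPath, GzipCodec.class);

    // Hypothetical parser for one output line "vid\t[v1,v2,...]", e.g. for a
    // later Milvus import step (no Milvus client calls are shown here).
    static java.util.Map.Entry<String, float[]> parseVectorLine(String line) {
        String[] kv = line.split("\t", 2);
        String vid = kv[0];
        String body = kv[1].substring(1, kv[1].length() - 1); // strip '[' and ']'
        String[] parts = body.split(",");
        float[] vec = new float[parts.length];
        for (int i = 0; i < parts.length; i++) {
            vec[i] = Float.parseFloat(parts[i].trim());
        }
        return new java.util.AbstractMap.SimpleEntry<>(vid, vec);
    }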