丁云鹏, 4 months ago
Commit
b0c62372ed

+ 106 - 2
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMPredict.java

@@ -1,6 +1,20 @@
 package com.tzld.piaoquan.recommend.model.produce.i2i;
 
+import com.baidu.paddle.inference.Config;
+import com.baidu.paddle.inference.Predictor;
+import com.baidu.paddle.inference.Tensor;
+import com.tzld.piaoquan.recommend.model.produce.service.CMDService;
+import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
+import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
 import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.hadoop.io.compress.GzipCodec;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
+
+import java.util.Iterator;
+import java.util.Map;
 
 /**
  * @author dyp
@@ -10,7 +24,97 @@ public class I2IDSSMPredict {
 
     public static void main(String[] args) {
         System.loadLibrary("paddle_inference");
-        I2IDSSMService dssm = new I2IDSSMService();
-        dssm.predict(args);
+        CMDService cmd = new CMDService();
+        Map<String, String> argMap = cmd.parse(args);
+        String file = argMap.get("path");
+        int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
+
+        // 加载模型
+        SparkSession spark = SparkSession.builder()
+                .appName("I2IDSSMInfer")
+                .getOrCreate();
+
+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+        JavaRDD<String> rdd = jsc.textFile(file);
+
+        // 定义处理数据的函数
+        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
+            String bucketName = "art-recommend";
+            String objectName = "dyp/dssm.tar.gz";
+            OSSService ossService = new OSSService();
+
+            String gzPath = "/root/recommend-model/model.tar.gz";
+            ossService.download(bucketName, gzPath, objectName);
+            String modelDir = "/root/recommend-model";
+            CompressUtil.decompressGzFile(gzPath, modelDir);
+
+            String modelFile = modelDir + "/dssm.pdmodel";
+            String paramFile = modelDir + "/dssm.pdiparams";
+
+            Config config = new Config();
+            config.setCppModel(modelFile, paramFile);
+            config.enableMemoryOptim(true);
+            config.enableMKLDNN();
+            config.switchIrDebug(false);
+
+            Predictor predictor = Predictor.createPaddlePredictor(config);
+
+
+            return new Iterator<String>() {
+                private final Iterator<String> iterator = lines;
+
+                @Override
+                public boolean hasNext() {
+                    return iterator.hasNext();
+                }
+
+                @Override
+                public String next() {
+                    // 1 处理数据
+                    String line = lines.next();
+                    String[] sampleValues = line.split("\t", -1); // -1参数保持尾部空字符串
+
+                    // 检查是否有至少两个元素(vid和left_features_str)
+                    if (sampleValues.length >= 2) {
+                        String vid = sampleValues[0];
+                        String leftFeaturesStr = sampleValues[1];
+
+                        // 分割left_features_str并转换为float数组
+                        String[] leftFeaturesArray = leftFeaturesStr.split(",");
+                        float[] leftFeatures = new float[leftFeaturesArray.length];
+                        for (int i = 0; i < leftFeaturesArray.length; i++) {
+                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
+                        }
+                        String inNames = predictor.getInputNameById(0);
+                        Tensor inHandle = predictor.getInputHandle(inNames);
+                        // 2 设置输入
+                        inHandle.reshape(2, new int[]{1, 157});
+                        inHandle.copyFromCpu(leftFeatures);
+
+                        // 3 预测
+                        predictor.run();
+
+                        // 4 获取输入Tensor
+                        String outNames = predictor.getOutputNameById(0);
+                        Tensor outHandle = predictor.getOutputHandle(outNames);
+                        float[] outData = new float[outHandle.getSize()];
+                        outHandle.copyToCpu(outData);
+
+                        String result = vid + "\t" + outData[0];
+
+                        outHandle.destroyNativeTensor();
+                        inHandle.destroyNativeTensor();
+                        predictor.destroyNativePredictor();
+
+                        return result;
+                    }
+                    return "";
+                }
+            };
+        });
+        // 将处理后的数据写入新的文件,使用Gzip压缩
+        String outputPath = "hdfs:/dyp/vec2";
+        processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);
     }
+
 }