丁云鹏 4 meses atrás
pai
commit
991a553cf5

+ 83 - 82
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMPredict.java

@@ -43,88 +43,89 @@ public class I2IDSSMPredict {
 
         // 定义处理数据的函数
         JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
-            System.loadLibrary("paddle_inference");
-            String bucketName = "art-recommend";
-            String objectName = "dyp/dssm.tar.gz";
-            OSSService ossService = new OSSService();
-
-            String gzPath = "/root/recommend-model/model.tar.gz";
-            ossService.download(bucketName, gzPath, objectName);
-            String modelDir = "/root/recommend-model";
-            CompressUtil.decompressGzFile(gzPath, modelDir);
-
-            String modelFile = modelDir + "/dssm.pdmodel";
-            String paramFile = modelDir + "/dssm.pdiparams";
-
-            Config config = new Config();
-            config.setCppModel(modelFile, paramFile);
-            config.enableMemoryOptim(true);
-            config.enableMKLDNN();
-            config.switchIrDebug(false);
-
-            Predictor predictor = Predictor.createPaddlePredictor(config);
-
-
-            return new Iterator<String>() {
-                private final Iterator<String> iterator = lines;
-
-                @Override
-                public boolean hasNext() {
-                    return iterator.hasNext();
-                }
-
-                @Override
-                public String next() {
-                    // 1 处理数据
-                    String line = lines.next();
-                    String[] sampleValues = line.split("\t", -1); // -1参数保持尾部空字符串
-
-                    // 检查是否有至少两个元素(vid和left_features_str)
-                    if (sampleValues.length >= 2) {
-                        String vid = sampleValues[0];
-                        String leftFeaturesStr = sampleValues[1];
-
-                        // 分割left_features_str并转换为float数组
-                        String[] leftFeaturesArray = leftFeaturesStr.split(",");
-                        float[] leftFeatures = new float[leftFeaturesArray.length];
-                        for (int i = 0; i < leftFeaturesArray.length; i++) {
-                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
-                        }
-                        String inNames = predictor.getInputNameById(0);
-                        Tensor inHandle = predictor.getInputHandle(inNames);
-                        // 2 设置输入
-                        inHandle.reshape(2, new int[]{1, 157});
-                        inHandle.copyFromCpu(leftFeatures);
-
-                        // 3 预测
-                        predictor.run();
-
-                        // 4 获取输入Tensor
-                        String outNames = predictor.getOutputNameById(0);
-                        Tensor outHandle = predictor.getOutputHandle(outNames);
-                        float[] outData = new float[outHandle.getSize()];
-                        outHandle.copyToCpu(outData);
-
-
-                        String result = vid + "\t[" + StringUtils.join(outData, ',') + "]";
-
-                        outHandle.destroyNativeTensor();
-                        inHandle.destroyNativeTensor();
-
-                        return result;
-                    }
-                    return "";
-                }
-            };
-        });
-        // 将处理后的数据写入新的文件,使用Gzip压缩
-        String outputPath = "hdfs:/dyp/vec2";
-        try {
-            hdfsService.deleteIfExist(outputPath);
-        } catch (Exception e) {
-            log.error("deleteOnExit error outputPath {}", outputPath, e);
-        }
-        processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);
+                    System.loadLibrary("paddle_inference");
+                    String bucketName = "art-recommend";
+                    String objectName = "dyp/dssm.tar.gz";
+                    OSSService ossService = new OSSService();
+
+                    String gzPath = "/root/recommend-model/model.tar.gz";
+                    ossService.download(bucketName, gzPath, objectName);
+                });
+//            String modelDir = "/root/recommend-model";
+//            CompressUtil.decompressGzFile(gzPath, modelDir);
+//
+//            String modelFile = modelDir + "/dssm.pdmodel";
+//            String paramFile = modelDir + "/dssm.pdiparams";
+//
+//            Config config = new Config();
+//            config.setCppModel(modelFile, paramFile);
+//            config.enableMemoryOptim(true);
+//            config.enableMKLDNN();
+//            config.switchIrDebug(false);
+//
+//            Predictor predictor = Predictor.createPaddlePredictor(config);
+//
+//
+//            return new Iterator<String>() {
+//                private final Iterator<String> iterator = lines;
+//
+//                @Override
+//                public boolean hasNext() {
+//                    return iterator.hasNext();
+//                }
+//
+//                @Override
+//                public String next() {
+//                    // 1 处理数据
+//                    String line = lines.next();
+//                    String[] sampleValues = line.split("\t", -1); // -1参数保持尾部空字符串
+//
+//                    // 检查是否有至少两个元素(vid和left_features_str)
+//                    if (sampleValues.length >= 2) {
+//                        String vid = sampleValues[0];
+//                        String leftFeaturesStr = sampleValues[1];
+//
+//                        // 分割left_features_str并转换为float数组
+//                        String[] leftFeaturesArray = leftFeaturesStr.split(",");
+//                        float[] leftFeatures = new float[leftFeaturesArray.length];
+//                        for (int i = 0; i < leftFeaturesArray.length; i++) {
+//                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
+//                        }
+//                        String inNames = predictor.getInputNameById(0);
+//                        Tensor inHandle = predictor.getInputHandle(inNames);
+//                        // 2 设置输入
+//                        inHandle.reshape(2, new int[]{1, 157});
+//                        inHandle.copyFromCpu(leftFeatures);
+//
+//                        // 3 预测
+//                        predictor.run();
+//
+//                        // 4 获取输入Tensor
+//                        String outNames = predictor.getOutputNameById(0);
+//                        Tensor outHandle = predictor.getOutputHandle(outNames);
+//                        float[] outData = new float[outHandle.getSize()];
+//                        outHandle.copyToCpu(outData);
+//
+//
+//                        String result = vid + "\t[" + StringUtils.join(outData, ',') + "]";
+//
+//                        outHandle.destroyNativeTensor();
+//                        inHandle.destroyNativeTensor();
+//
+//                        return result;
+//                    }
+//                    return "";
+//                }
+//            };
+//        });
+//        // 将处理后的数据写入新的文件,使用Gzip压缩
+//        String outputPath = "hdfs:/dyp/vec2";
+//        try {
+//            hdfsService.deleteIfExist(outputPath);
+//        } catch (Exception e) {
+//            log.error("deleteOnExit error outputPath {}", outputPath, e);
+//        }
+//        processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);
     }
 
 }