|  | @@ -43,89 +43,88 @@ public class I2IDSSMPredict {
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |          // 定义处理数据的函数
 |  |          // 定义处理数据的函数
 | 
											
												
													
														|  |          JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
 |  |          JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
 | 
											
												
													
														|  | -                    System.loadLibrary("paddle_inference");
 |  | 
 | 
											
												
													
														|  | -                    String bucketName = "art-recommend";
 |  | 
 | 
											
												
													
														|  | -                    String objectName = "dyp/dssm.tar.gz";
 |  | 
 | 
											
												
													
														|  | -                    OSSService ossService = new OSSService();
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -                    String gzPath = "/root/recommend-model/model.tar.gz";
 |  | 
 | 
											
												
													
														|  | -                    ossService.download(bucketName, gzPath, objectName);
 |  | 
 | 
											
												
													
														|  | -                });
 |  | 
 | 
											
												
													
														|  | -//            String modelDir = "/root/recommend-model";
 |  | 
 | 
											
												
													
														|  | -//            CompressUtil.decompressGzFile(gzPath, modelDir);
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//            String modelFile = modelDir + "/dssm.pdmodel";
 |  | 
 | 
											
												
													
														|  | -//            String paramFile = modelDir + "/dssm.pdiparams";
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//            Config config = new Config();
 |  | 
 | 
											
												
													
														|  | -//            config.setCppModel(modelFile, paramFile);
 |  | 
 | 
											
												
													
														|  | -//            config.enableMemoryOptim(true);
 |  | 
 | 
											
												
													
														|  | -//            config.enableMKLDNN();
 |  | 
 | 
											
												
													
														|  | -//            config.switchIrDebug(false);
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//            Predictor predictor = Predictor.createPaddlePredictor(config);
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//            return new Iterator<String>() {
 |  | 
 | 
											
												
													
														|  | -//                private final Iterator<String> iterator = lines;
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//                @Override
 |  | 
 | 
											
												
													
														|  | -//                public boolean hasNext() {
 |  | 
 | 
											
												
													
														|  | -//                    return iterator.hasNext();
 |  | 
 | 
											
												
													
														|  | -//                }
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//                @Override
 |  | 
 | 
											
												
													
														|  | -//                public String next() {
 |  | 
 | 
											
												
													
														|  | -//                    // 1 处理数据
 |  | 
 | 
											
												
													
														|  | -//                    String line = lines.next();
 |  | 
 | 
											
												
													
														|  | -//                    String[] sampleValues = line.split("\t", -1); // -1参数保持尾部空字符串
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//                    // 检查是否有至少两个元素(vid和left_features_str)
 |  | 
 | 
											
												
													
														|  | -//                    if (sampleValues.length >= 2) {
 |  | 
 | 
											
												
													
														|  | -//                        String vid = sampleValues[0];
 |  | 
 | 
											
												
													
														|  | -//                        String leftFeaturesStr = sampleValues[1];
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//                        // 分割left_features_str并转换为float数组
 |  | 
 | 
											
												
													
														|  | -//                        String[] leftFeaturesArray = leftFeaturesStr.split(",");
 |  | 
 | 
											
												
													
														|  | -//                        float[] leftFeatures = new float[leftFeaturesArray.length];
 |  | 
 | 
											
												
													
														|  | -//                        for (int i = 0; i < leftFeaturesArray.length; i++) {
 |  | 
 | 
											
												
													
														|  | -//                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
 |  | 
 | 
											
												
													
														|  | -//                        }
 |  | 
 | 
											
												
													
														|  | -//                        String inNames = predictor.getInputNameById(0);
 |  | 
 | 
											
												
													
														|  | -//                        Tensor inHandle = predictor.getInputHandle(inNames);
 |  | 
 | 
											
												
													
														|  | -//                        // 2 设置输入
 |  | 
 | 
											
												
													
														|  | -//                        inHandle.reshape(2, new int[]{1, 157});
 |  | 
 | 
											
												
													
														|  | -//                        inHandle.copyFromCpu(leftFeatures);
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//                        // 3 预测
 |  | 
 | 
											
												
													
														|  | -//                        predictor.run();
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//                        // 4 获取输入Tensor
 |  | 
 | 
											
												
													
														|  | -//                        String outNames = predictor.getOutputNameById(0);
 |  | 
 | 
											
												
													
														|  | -//                        Tensor outHandle = predictor.getOutputHandle(outNames);
 |  | 
 | 
											
												
													
														|  | -//                        float[] outData = new float[outHandle.getSize()];
 |  | 
 | 
											
												
													
														|  | -//                        outHandle.copyToCpu(outData);
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//                        String result = vid + "\t[" + StringUtils.join(outData, ',') + "]";
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//                        outHandle.destroyNativeTensor();
 |  | 
 | 
											
												
													
														|  | -//                        inHandle.destroyNativeTensor();
 |  | 
 | 
											
												
													
														|  | -//
 |  | 
 | 
											
												
													
														|  | -//                        return result;
 |  | 
 | 
											
												
													
														|  | -//                    }
 |  | 
 | 
											
												
													
														|  | -//                    return "";
 |  | 
 | 
											
												
													
														|  | -//                }
 |  | 
 | 
											
												
													
														|  | -//            };
 |  | 
 | 
											
												
													
														|  | -//        });
 |  | 
 | 
											
												
													
														|  | -//        // 将处理后的数据写入新的文件,使用Gzip压缩
 |  | 
 | 
											
												
													
														|  | -//        String outputPath = "hdfs:/dyp/vec2";
 |  | 
 | 
											
												
													
														|  | -//        try {
 |  | 
 | 
											
												
													
														|  | -//            hdfsService.deleteIfExist(outputPath);
 |  | 
 | 
											
												
													
														|  | -//        } catch (Exception e) {
 |  | 
 | 
											
												
													
														|  | -//            log.error("deleteOnExit error outputPath {}", outputPath, e);
 |  | 
 | 
											
												
													
														|  | -//        }
 |  | 
 | 
											
												
													
														|  | -//        processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);
 |  | 
 | 
											
												
													
														|  | 
 |  | +            System.loadLibrary("paddle_inference");
 | 
											
												
													
														|  | 
 |  | +            String bucketName = "art-recommend";
 | 
											
												
													
														|  | 
 |  | +            String objectName = "dyp/dssm.tar.gz";
 | 
											
												
													
														|  | 
 |  | +            OSSService ossService = new OSSService();
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +            String gzPath = "/root/recommend-model/model.tar.gz";
 | 
											
												
													
														|  | 
 |  | +            ossService.download(bucketName, gzPath, objectName);
 | 
											
												
													
														|  | 
 |  | +            String modelDir = "/root/recommend-model";
 | 
											
												
													
														|  | 
 |  | +            CompressUtil.decompressGzFile(gzPath, modelDir);
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +            String modelFile = modelDir + "/dssm.pdmodel";
 | 
											
												
													
														|  | 
 |  | +            String paramFile = modelDir + "/dssm.pdiparams";
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +            Config config = new Config();
 | 
											
												
													
														|  | 
 |  | +            config.setCppModel(modelFile, paramFile);
 | 
											
												
													
														|  | 
 |  | +            config.enableMemoryOptim(true);
 | 
											
												
													
														|  | 
 |  | +            config.enableMKLDNN();
 | 
											
												
													
														|  | 
 |  | +            config.switchIrDebug(false);
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +            Predictor predictor = Predictor.createPaddlePredictor(config);
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +            return new Iterator<String>() {
 | 
											
												
													
														|  | 
 |  | +                private final Iterator<String> iterator = lines;
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +                @Override
 | 
											
												
													
														|  | 
 |  | +                public boolean hasNext() {
 | 
											
												
													
														|  | 
 |  | +                    return iterator.hasNext();
 | 
											
												
													
														|  | 
 |  | +                }
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +                @Override
 | 
											
												
													
														|  | 
 |  | +                public String next() {
 | 
											
												
													
														|  | 
 |  | +                    // 1 处理数据
 | 
											
												
													
														|  | 
 |  | +                    String line = lines.next();
 | 
											
												
													
														|  | 
 |  | +                    String[] sampleValues = line.split("\t", -1); // -1参数保持尾部空字符串
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +                    // 检查是否有至少两个元素(vid和left_features_str)
 | 
											
												
													
														|  | 
 |  | +                    if (sampleValues.length >= 2) {
 | 
											
												
													
														|  | 
 |  | +                        String vid = sampleValues[0];
 | 
											
												
													
														|  | 
 |  | +                        String leftFeaturesStr = sampleValues[1];
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +                        // 分割left_features_str并转换为float数组
 | 
											
												
													
														|  | 
 |  | +                        String[] leftFeaturesArray = leftFeaturesStr.split(",");
 | 
											
												
													
														|  | 
 |  | +                        float[] leftFeatures = new float[leftFeaturesArray.length];
 | 
											
												
													
														|  | 
 |  | +                        for (int i = 0; i < leftFeaturesArray.length; i++) {
 | 
											
												
													
														|  | 
 |  | +                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
 | 
											
												
													
														|  | 
 |  | +                        }
 | 
											
												
													
														|  | 
 |  | +                        String inNames = predictor.getInputNameById(0);
 | 
											
												
													
														|  | 
 |  | +                        Tensor inHandle = predictor.getInputHandle(inNames);
 | 
											
												
													
														|  | 
 |  | +                        // 2 设置输入
 | 
											
												
													
														|  | 
 |  | +                        inHandle.reshape(2, new int[]{1, 157});
 | 
											
												
													
														|  | 
 |  | +                        inHandle.copyFromCpu(leftFeatures);
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +                        // 3 预测
 | 
											
												
													
														|  | 
 |  | +                        predictor.run();
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +                        // 4 Get the output tensor
 | 
											
												
													
														|  | 
 |  | +                        String outNames = predictor.getOutputNameById(0);
 | 
											
												
													
														|  | 
 |  | +                        Tensor outHandle = predictor.getOutputHandle(outNames);
 | 
											
												
													
														|  | 
 |  | +                        float[] outData = new float[outHandle.getSize()];
 | 
											
												
													
														|  | 
 |  | +                        outHandle.copyToCpu(outData);
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +                        String result = vid + "\t[" + StringUtils.join(outData, ',') + "]";
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +                        outHandle.destroyNativeTensor();
 | 
											
												
													
														|  | 
 |  | +                        inHandle.destroyNativeTensor();
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +                        return result;
 | 
											
												
													
														|  | 
 |  | +                    }
 | 
											
												
													
														|  | 
 |  | +                    return "";
 | 
											
												
													
														|  | 
 |  | +                }
 | 
											
												
													
														|  | 
 |  | +            };
 | 
											
												
													
														|  | 
 |  | +        });
 | 
											
												
													
														|  | 
 |  | +        // 将处理后的数据写入新的文件,使用Gzip压缩
 | 
											
												
													
														|  | 
 |  | +        String outputPath = "hdfs:/dyp/vec2";
 | 
											
												
													
														|  | 
 |  | +        try {
 | 
											
												
													
														|  | 
 |  | +            hdfsService.deleteIfExist(outputPath);
 | 
											
												
													
														|  | 
 |  | +        } catch (Exception e) {
 | 
											
												
													
														|  | 
 |  | +            log.error("deleteOnExit error outputPath {}", outputPath, e);
 | 
											
												
													
														|  | 
 |  | +        }
 | 
											
												
													
														|  | 
 |  | +        processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);
 | 
											
												
													
														|  |      }
 |  |      }
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  }
 |  |  }
 |