Browse Source

dssm train

丁云鹏 4 months ago
parent
commit
468b217577

+ 82 - 83
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IDSSMPredict.java

@@ -43,89 +43,88 @@ public class I2IDSSMPredict {
 
         // 定义处理数据的函数
         JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
-                    System.loadLibrary("paddle_inference");
-                    String bucketName = "art-recommend";
-                    String objectName = "dyp/dssm.tar.gz";
-                    OSSService ossService = new OSSService();
-
-                    String gzPath = "/root/recommend-model/model.tar.gz";
-                    ossService.download(bucketName, gzPath, objectName);
-                });
-//            String modelDir = "/root/recommend-model";
-//            CompressUtil.decompressGzFile(gzPath, modelDir);
-//
-//            String modelFile = modelDir + "/dssm.pdmodel";
-//            String paramFile = modelDir + "/dssm.pdiparams";
-//
-//            Config config = new Config();
-//            config.setCppModel(modelFile, paramFile);
-//            config.enableMemoryOptim(true);
-//            config.enableMKLDNN();
-//            config.switchIrDebug(false);
-//
-//            Predictor predictor = Predictor.createPaddlePredictor(config);
-//
-//
-//            return new Iterator<String>() {
-//                private final Iterator<String> iterator = lines;
-//
-//                @Override
-//                public boolean hasNext() {
-//                    return iterator.hasNext();
-//                }
-//
-//                @Override
-//                public String next() {
-//                    // 1 处理数据
-//                    String line = lines.next();
-//                    String[] sampleValues = line.split("\t", -1); // -1参数保持尾部空字符串
-//
-//                    // 检查是否有至少两个元素(vid和left_features_str)
-//                    if (sampleValues.length >= 2) {
-//                        String vid = sampleValues[0];
-//                        String leftFeaturesStr = sampleValues[1];
-//
-//                        // 分割left_features_str并转换为float数组
-//                        String[] leftFeaturesArray = leftFeaturesStr.split(",");
-//                        float[] leftFeatures = new float[leftFeaturesArray.length];
-//                        for (int i = 0; i < leftFeaturesArray.length; i++) {
-//                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
-//                        }
-//                        String inNames = predictor.getInputNameById(0);
-//                        Tensor inHandle = predictor.getInputHandle(inNames);
-//                        // 2 设置输入
-//                        inHandle.reshape(2, new int[]{1, 157});
-//                        inHandle.copyFromCpu(leftFeatures);
-//
-//                        // 3 预测
-//                        predictor.run();
-//
-//                        // 4 获取输入Tensor
-//                        String outNames = predictor.getOutputNameById(0);
-//                        Tensor outHandle = predictor.getOutputHandle(outNames);
-//                        float[] outData = new float[outHandle.getSize()];
-//                        outHandle.copyToCpu(outData);
-//
-//
-//                        String result = vid + "\t[" + StringUtils.join(outData, ',') + "]";
-//
-//                        outHandle.destroyNativeTensor();
-//                        inHandle.destroyNativeTensor();
-//
-//                        return result;
-//                    }
-//                    return "";
-//                }
-//            };
-//        });
-//        // 将处理后的数据写入新的文件,使用Gzip压缩
-//        String outputPath = "hdfs:/dyp/vec2";
-//        try {
-//            hdfsService.deleteIfExist(outputPath);
-//        } catch (Exception e) {
-//            log.error("deleteOnExit error outputPath {}", outputPath, e);
-//        }
-//        processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);
+            System.loadLibrary("paddle_inference");
+            String bucketName = "art-recommend";
+            String objectName = "dyp/dssm.tar.gz";
+            OSSService ossService = new OSSService();
+
+            String gzPath = "/root/recommend-model/model.tar.gz";
+            ossService.download(bucketName, gzPath, objectName);
+            String modelDir = "/root/recommend-model";
+            CompressUtil.decompressGzFile(gzPath, modelDir);
+
+            String modelFile = modelDir + "/dssm.pdmodel";
+            String paramFile = modelDir + "/dssm.pdiparams";
+
+            Config config = new Config();
+            config.setCppModel(modelFile, paramFile);
+            config.enableMemoryOptim(true);
+            config.enableMKLDNN();
+            config.switchIrDebug(false);
+
+            Predictor predictor = Predictor.createPaddlePredictor(config);
+
+
+            return new Iterator<String>() {
+                private final Iterator<String> iterator = lines;
+
+                @Override
+                public boolean hasNext() {
+                    return iterator.hasNext();
+                }
+
+                @Override
+                public String next() {
+                    // 1 处理数据
+                    String line = lines.next();
+                    String[] sampleValues = line.split("\t", -1); // -1参数保持尾部空字符串
+
+                    // 检查是否有至少两个元素(vid和left_features_str)
+                    if (sampleValues.length >= 2) {
+                        String vid = sampleValues[0];
+                        String leftFeaturesStr = sampleValues[1];
+
+                        // 分割left_features_str并转换为float数组
+                        String[] leftFeaturesArray = leftFeaturesStr.split(",");
+                        float[] leftFeatures = new float[leftFeaturesArray.length];
+                        for (int i = 0; i < leftFeaturesArray.length; i++) {
+                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
+                        }
+                        String inNames = predictor.getInputNameById(0);
+                        Tensor inHandle = predictor.getInputHandle(inNames);
+                        // 2 设置输入
+                        inHandle.reshape(2, new int[]{1, 157});
+                        inHandle.copyFromCpu(leftFeatures);
+
+                        // 3 预测
+                        predictor.run();
+
+                        // 4 获取输入Tensor
+                        String outNames = predictor.getOutputNameById(0);
+                        Tensor outHandle = predictor.getOutputHandle(outNames);
+                        float[] outData = new float[outHandle.getSize()];
+                        outHandle.copyToCpu(outData);
+
+
+                        String result = vid + "\t[" + StringUtils.join(outData, ',') + "]";
+
+                        outHandle.destroyNativeTensor();
+                        inHandle.destroyNativeTensor();
+
+                        return result;
+                    }
+                    return "";
+                }
+            };
+        });
+        // 将处理后的数据写入新的文件,使用Gzip压缩
+        String outputPath = "hdfs:/dyp/vec2";
+        try {
+            hdfsService.deleteIfExist(outputPath);
+        } catch (Exception e) {
+            log.error("deleteOnExit error outputPath {}", outputPath, e);
+        }
+        processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);
     }
 
 }

+ 2 - 2
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/OSSService.java

@@ -17,8 +17,8 @@ import java.io.Serializable;
 public class OSSService implements Serializable {
     private String accessId = "LTAI5tHMkNaRhpiDB1yWMZPn";
     private String accessKey = "XLi5YUJusVwbbQOaGeGsaRJ1Qyzbui";
-    private String endpoint = "https://oss-cn-hangzhou-internal.aliyuncs.com";
-    //private String endpoint = "https://oss-cn-hangzhou.aliyuncs.com";
+    //private String endpoint = "https://oss-cn-hangzhou-internal.aliyuncs.com";
+    private String endpoint = "https://oss-cn-hangzhou.aliyuncs.com";
 
     public void upload(String bucketName, String localFile, String objectName) {
         OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);