| 
														
															@@ -1,6 +1,20 @@ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 package com.tzld.piaoquan.recommend.model.produce.i2i; 
														 | 
														
														 | 
														
															 package com.tzld.piaoquan.recommend.model.produce.i2i; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import com.baidu.paddle.inference.Config; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import com.baidu.paddle.inference.Predictor; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import com.baidu.paddle.inference.Tensor; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import com.tzld.piaoquan.recommend.model.produce.service.CMDService; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import com.tzld.piaoquan.recommend.model.produce.service.OSSService; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 import lombok.extern.slf4j.Slf4j; 
														 | 
														
														 | 
														
															 import lombok.extern.slf4j.Slf4j; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import org.apache.commons.lang.math.NumberUtils; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import org.apache.hadoop.io.compress.GzipCodec; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import org.apache.spark.api.java.JavaRDD; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import org.apache.spark.api.java.JavaSparkContext; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import org.apache.spark.sql.SparkSession; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import java.util.Iterator; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+import java.util.Map; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															 /** 
														 | 
														
														 | 
														
															 /** 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  * @author dyp 
														 | 
														
														 | 
														
															  * @author dyp 
														 | 
													
												
											
										
											
												
													
														 | 
														
															@@ -10,7 +24,97 @@ public class I2IDSSMPredict { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															  
														 | 
														
														 | 
														
															  
														 | 
													
												
											
												
													
														| 
														 | 
														
															     public static void main(String[] args) { 
														 | 
														
														 | 
														
															     public static void main(String[] args) { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															         System.loadLibrary("paddle_inference"); 
														 | 
														
														 | 
														
															         System.loadLibrary("paddle_inference"); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        I2IDSSMService dssm = new I2IDSSMService(); 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															-        dssm.predict(args); 
														 | 
														
														 | 
														
															 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        CMDService cmd = new CMDService(); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        Map<String, String> argMap = cmd.parse(args); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        String file = argMap.get("path"); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        int repartition = NumberUtils.toInt(argMap.get("repartition"), 64); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        // 加载模型 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        SparkSession spark = SparkSession.builder() 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                .appName("I2IDSSMInfer") 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                .getOrCreate(); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        JavaRDD<String> rdd = jsc.textFile(file); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        // 定义处理数据的函数 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            String bucketName = "art-recommend"; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            String objectName = "dyp/dssm.tar.gz"; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            OSSService ossService = new OSSService(); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            String gzPath = "/root/recommend-model/model.tar.gz"; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            ossService.download(bucketName, gzPath, objectName); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            String modelDir = "/root/recommend-model"; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            CompressUtil.decompressGzFile(gzPath, modelDir); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            String modelFile = modelDir + "/dssm.pdmodel"; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            String paramFile = modelDir + "/dssm.pdiparams"; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            Config config = new Config(); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            config.setCppModel(modelFile, paramFile); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            config.enableMemoryOptim(true); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            config.enableMKLDNN(); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            config.switchIrDebug(false); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            Predictor predictor = Predictor.createPaddlePredictor(config); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            return new Iterator<String>() { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                private final Iterator<String> iterator = lines; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                @Override 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                public boolean hasNext() { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    return iterator.hasNext(); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                } 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                @Override 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                public String next() { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    // 1 处理数据 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    String line = lines.next(); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    String[] sampleValues = line.split("\t", -1); // -1参数保持尾部空字符串 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    // 检查是否有至少两个元素(vid和left_features_str) 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    if (sampleValues.length >= 2) { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        String vid = sampleValues[0]; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        String leftFeaturesStr = sampleValues[1]; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        // 分割left_features_str并转换为float数组 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        String[] leftFeaturesArray = leftFeaturesStr.split(","); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        float[] leftFeatures = new float[leftFeaturesArray.length]; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        for (int i = 0; i < leftFeaturesArray.length; i++) { 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        } 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        String inNames = predictor.getInputNameById(0); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        Tensor inHandle = predictor.getInputHandle(inNames); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        // 2 设置输入 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        inHandle.reshape(2, new int[]{1, 157}); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        inHandle.copyFromCpu(leftFeatures); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        // 3 预测 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        predictor.run(); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        // 4 获取输入Tensor 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        String outNames = predictor.getOutputNameById(0); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        Tensor outHandle = predictor.getOutputHandle(outNames); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        float[] outData = new float[outHandle.getSize()]; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        outHandle.copyToCpu(outData); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        String result = vid + "\t" + outData[0]; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        outHandle.destroyNativeTensor(); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        inHandle.destroyNativeTensor(); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        predictor.destroyNativePredictor(); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                        return result; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    } 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                    return ""; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+                } 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+            }; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        }); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        // 将处理后的数据写入新的文件,使用Gzip压缩 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        String outputPath = "hdfs:/dyp/vec2"; 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+        processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class); 
														 | 
													
												
											
												
													
														| 
														 | 
														
															     } 
														 | 
														
														 | 
														
															     } 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 
														 | 
														
														 | 
														
															+ 
														 | 
													
												
											
												
													
														| 
														 | 
														
															 } 
														 | 
														
														 | 
														
															 } 
														 |