| 
					
				 | 
			
			
				@@ -0,0 +1,133 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+package com.tzld.piaoquan.recommend.model.produce.i2i; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import com.baidu.paddle.inference.Config; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import com.baidu.paddle.inference.Predictor; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import com.baidu.paddle.inference.Tensor; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import com.tzld.piaoquan.recommend.model.produce.service.CMDService; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import com.tzld.piaoquan.recommend.model.produce.service.HDFSService; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import com.tzld.piaoquan.recommend.model.produce.service.OSSService; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import lombok.extern.slf4j.Slf4j; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.commons.lang.math.NumberUtils; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.commons.lang3.StringUtils; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.hadoop.io.compress.GzipCodec; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.api.java.JavaRDD; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.api.java.JavaSparkContext; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.apache.spark.sql.SparkSession; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.io.IOException; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.util.Iterator; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import java.util.Map; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+/** 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ * @author dyp 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ */ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+@Slf4j 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+public class I2IMilvusDataImport { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    private static HDFSService hdfsService = new HDFSService(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    public static void main(String[] args) throws IOException { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        CMDService cmd = new CMDService(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        Map<String, String> argMap = cmd.parse(args); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String file = argMap.get("vecOutputPath"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String modelOssObjectName = argMap.get("modelOssObjectName"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String modelOssBucketName = argMap.get("modelOssBucketName"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String modelHdfsSavePath = argMap.get("modelHdfsSavePath"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String vecOutputPath = argMap.get("vecOutputPath"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        int repartition = NumberUtils.toInt(argMap.get("repartition"), 64); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        OSSService ossService = new OSSService(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        String gzPath = "/root/recommend-model/model.tar.gz"; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        ossService.download(modelOssBucketName, gzPath, modelOssObjectName); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        hdfsService.copyFromLocalFile(gzPath, modelHdfsSavePath); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        // 加载模型 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        SparkSession spark = SparkSession.builder() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                .appName("I2IDSSMInfer") 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                .getOrCreate(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        JavaRDD<String> rdd = jsc.textFile(file); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        // 定义处理数据的函数 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            System.loadLibrary("paddle_inference"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            hdfsService.copyToLocalFile(modelHdfsSavePath, "./model.tar.gz"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            CompressUtil.decompressGzFile("./model.tar.gz", "."); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String modelFile = "dssm.pdmodel"; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            String paramFile = "dssm.pdiparams"; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Config config = new Config(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            config.setCppModel(modelFile, paramFile); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            config.enableMemoryOptim(true); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            config.enableMKLDNN(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            config.switchIrDebug(false); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            Predictor predictor = Predictor.createPaddlePredictor(config); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            return new Iterator<String>() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                private final Iterator<String> iterator = lines; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                @Override 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                public boolean hasNext() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    return iterator.hasNext(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                @Override 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                public String next() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    // 1 处理数据 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    String line = lines.next(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    String[] sampleValues = line.split("\t", -1); // -1参数保持尾部空字符串 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    // 检查是否有至少两个元素(vid和left_features_str) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    if (sampleValues.length >= 2) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        String vid = sampleValues[0]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        String leftFeaturesStr = sampleValues[1]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        // 分割left_features_str并转换为float数组 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        String[] leftFeaturesArray = leftFeaturesStr.split(","); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        float[] leftFeatures = new float[leftFeaturesArray.length]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        for (int i = 0; i < leftFeaturesArray.length; i++) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        String inNames = predictor.getInputNameById(0); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        Tensor inHandle = predictor.getInputHandle(inNames); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        // 2 设置输入 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        inHandle.reshape(2, new int[]{1, 157}); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        inHandle.copyFromCpu(leftFeatures); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        // 3 预测 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        predictor.run(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        // 4 获取输入Tensor 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        String outNames = predictor.getOutputNameById(0); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        Tensor outHandle = predictor.getOutputHandle(outNames); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        float[] outData = new float[outHandle.getSize()]; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        outHandle.copyToCpu(outData); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        String result = vid + "\t[" + StringUtils.join(outData, ',') + "]"; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        outHandle.destroyNativeTensor(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        inHandle.destroyNativeTensor(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                        return result; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                    return ""; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+                } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            }; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        }); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        // 将处理后的数据写入新的文件,使用Gzip压缩 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        try { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            hdfsService.deleteIfExist(vecOutputPath); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } catch (Exception e) { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+            log.error("deleteIfExist error outputPath {}", vecOutputPath, e); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+        processedRdd.coalesce(repartition).saveAsTextFile(vecOutputPath, GzipCodec.class); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+    } 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+} 
			 |