|  | @@ -1,133 +0,0 @@
 | 
											
												
													
														|  | -package com.tzld.piaoquan.recommend.model.produce.i2i;
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -import com.baidu.paddle.inference.Config;
 |  | 
 | 
											
												
													
														|  | -import com.baidu.paddle.inference.Predictor;
 |  | 
 | 
											
												
													
														|  | -import com.baidu.paddle.inference.Tensor;
 |  | 
 | 
											
												
													
														|  | -import com.tzld.piaoquan.recommend.model.produce.service.CMDService;
 |  | 
 | 
											
												
													
														|  | -import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
 |  | 
 | 
											
												
													
														|  | -import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
 |  | 
 | 
											
												
													
														|  | -import lombok.extern.slf4j.Slf4j;
 |  | 
 | 
											
												
													
														|  | -import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 |  | 
 | 
											
												
													
														|  | -import org.apache.commons.lang.math.NumberUtils;
 |  | 
 | 
											
												
													
														|  | -import org.apache.commons.lang3.StringUtils;
 |  | 
 | 
											
												
													
														|  | -import org.apache.hadoop.io.compress.GzipCodec;
 |  | 
 | 
											
												
													
														|  | -import org.apache.spark.api.java.JavaRDD;
 |  | 
 | 
											
												
													
														|  | -import org.apache.spark.api.java.JavaSparkContext;
 |  | 
 | 
											
												
													
														|  | -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 |  | 
 | 
											
												
													
														|  | -import org.apache.spark.ml.feature.VectorAssembler;
 |  | 
 | 
											
												
													
														|  | -import org.apache.spark.sql.Dataset;
 |  | 
 | 
											
												
													
														|  | -import org.apache.spark.sql.Row;
 |  | 
 | 
											
												
													
														|  | -import org.apache.spark.sql.RowFactory;
 |  | 
 | 
											
												
													
														|  | -import org.apache.spark.sql.SparkSession;
 |  | 
 | 
											
												
													
														|  | -import org.apache.spark.sql.types.DataTypes;
 |  | 
 | 
											
												
													
														|  | -import org.apache.spark.sql.types.StructField;
 |  | 
 | 
											
												
													
														|  | -import org.apache.spark.sql.types.StructType;
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -import java.io.Serializable;
 |  | 
 | 
											
												
													
														|  | -import java.util.*;
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -/**
 |  | 
 | 
											
												
													
														|  | - * @author dyp
 |  | 
 | 
											
												
													
														|  | - */
 |  | 
 | 
											
												
													
														|  | -@Slf4j
 |  | 
 | 
											
												
													
														|  | -public class I2IDSSMService implements Serializable {
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -    public void predict(String[] args) {
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        CMDService cmd = new CMDService();
 |  | 
 | 
											
												
													
														|  | -        Map<String, String> argMap = cmd.parse(args);
 |  | 
 | 
											
												
													
														|  | -        String file = argMap.get("path");
 |  | 
 | 
											
												
													
														|  | -        int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        // 加载模型
 |  | 
 | 
											
												
													
														|  | -        SparkSession spark = SparkSession.builder()
 |  | 
 | 
											
												
													
														|  | -                .appName("I2IDSSMInfer")
 |  | 
 | 
											
												
													
														|  | -                .getOrCreate();
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
 |  | 
 | 
											
												
													
														|  | -        JavaRDD<String> rdd = jsc.textFile(file);
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        // 定义处理数据的函数
 |  | 
 | 
											
												
													
														|  | -        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
 |  | 
 | 
											
												
													
														|  | -            String bucketName = "art-recommend";
 |  | 
 | 
											
												
													
														|  | -            String objectName = "dyp/dssm.tar.gz";
 |  | 
 | 
											
												
													
														|  | -            OSSService ossService = new OSSService();
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            String gzPath = "/root/recommend-model/model.tar.gz";
 |  | 
 | 
											
												
													
														|  | -            ossService.download(bucketName, gzPath, objectName);
 |  | 
 | 
											
												
													
														|  | -            String modelDir = "/root/recommend-model";
 |  | 
 | 
											
												
													
														|  | -            CompressUtil.decompressGzFile(gzPath, modelDir);
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            String modelFile = modelDir + "/dssm.pdmodel";
 |  | 
 | 
											
												
													
														|  | -            String paramFile = modelDir + "/dssm.pdiparams";
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            Config config = new Config();
 |  | 
 | 
											
												
													
														|  | -            config.setCppModel(modelFile, paramFile);
 |  | 
 | 
											
												
													
														|  | -            config.enableMemoryOptim(true);
 |  | 
 | 
											
												
													
														|  | -            config.enableMKLDNN();
 |  | 
 | 
											
												
													
														|  | -            config.switchIrDebug(false);
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            Predictor predictor = Predictor.createPaddlePredictor(config);
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            return new Iterator<String>() {
 |  | 
 | 
											
												
													
														|  | -                private final Iterator<String> iterator = lines;
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -                @Override
 |  | 
 | 
											
												
													
														|  | -                public boolean hasNext() {
 |  | 
 | 
											
												
													
														|  | -                    return iterator.hasNext();
 |  | 
 | 
											
												
													
														|  | -                }
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -                @Override
 |  | 
 | 
											
												
													
														|  | -                public String next() {
 |  | 
 | 
											
												
													
														|  | -                    return processLine(iterator.next(), predictor);
 |  | 
 | 
											
												
													
														|  | -                }
 |  | 
 | 
											
												
													
														|  | -            };
 |  | 
 | 
											
												
													
														|  | -        });
 |  | 
 | 
											
												
													
														|  | -        // 将处理后的数据写入新的文件,使用Gzip压缩
 |  | 
 | 
											
												
													
														|  | -        String outputPath = "hdfs:/dyp/vec2";
 |  | 
 | 
											
												
													
														|  | -        processedRdd.coalesce(repartition).saveAsTextFile(outputPath, GzipCodec.class);
 |  | 
 | 
											
												
													
														|  | -    }
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -    private String processLine(String line, Predictor predictor) {
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        // 1 处理数据
 |  | 
 | 
											
												
													
														|  | -        String[] sampleValues = line.split("\t", -1); // -1参数保持尾部空字符串
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -        // 检查是否有至少两个元素(vid和left_features_str)
 |  | 
 | 
											
												
													
														|  | -        if (sampleValues.length >= 2) {
 |  | 
 | 
											
												
													
														|  | -            String vid = sampleValues[0];
 |  | 
 | 
											
												
													
														|  | -            String leftFeaturesStr = sampleValues[1];
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            // 分割left_features_str并转换为float数组
 |  | 
 | 
											
												
													
														|  | -            String[] leftFeaturesArray = leftFeaturesStr.split(",");
 |  | 
 | 
											
												
													
														|  | -            float[] leftFeatures = new float[leftFeaturesArray.length];
 |  | 
 | 
											
												
													
														|  | -            for (int i = 0; i < leftFeaturesArray.length; i++) {
 |  | 
 | 
											
												
													
														|  | -                leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
 |  | 
 | 
											
												
													
														|  | -            }
 |  | 
 | 
											
												
													
														|  | -            String inNames = predictor.getInputNameById(0);
 |  | 
 | 
											
												
													
														|  | -            Tensor inHandle = predictor.getInputHandle(inNames);
 |  | 
 | 
											
												
													
														|  | -            // 2 设置输入
 |  | 
 | 
											
												
													
														|  | -            inHandle.reshape(2, new int[]{1, 157});
 |  | 
 | 
											
												
													
														|  | -            inHandle.copyFromCpu(leftFeatures);
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            // 3 预测
 |  | 
 | 
											
												
													
														|  | -            predictor.run();
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            // 4 获取输入Tensor
 |  | 
 | 
											
												
													
														|  | -            String outNames = predictor.getOutputNameById(0);
 |  | 
 | 
											
												
													
														|  | -            Tensor outHandle = predictor.getOutputHandle(outNames);
 |  | 
 | 
											
												
													
														|  | -            float[] outData = new float[outHandle.getSize()];
 |  | 
 | 
											
												
													
														|  | -            outHandle.copyToCpu(outData);
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            String result = vid + "\t" + outData[0];
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            outHandle.destroyNativeTensor();
 |  | 
 | 
											
												
													
														|  | -            inHandle.destroyNativeTensor();
 |  | 
 | 
											
												
													
														|  | -            predictor.destroyNativePredictor();
 |  | 
 | 
											
												
													
														|  | -
 |  | 
 | 
											
												
													
														|  | -            return result;
 |  | 
 | 
											
												
													
														|  | -        }
 |  | 
 | 
											
												
													
														|  | -        return "";
 |  | 
 | 
											
												
													
														|  | -    }
 |  | 
 | 
											
												
													
														|  | -}
 |  | 
 |