@@ -0,0 +1,147 @@
+package com.tzld.piaoquan.recommend.model.produce.i2i;
+
+import com.baidu.paddle.inference.Config;
+import com.baidu.paddle.inference.Predictor;
+import com.baidu.paddle.inference.Tensor;
+import com.tzld.piaoquan.recommend.model.produce.service.CMDService;
+import com.tzld.piaoquan.recommend.model.produce.service.HDFSService;
+import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
+import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.hadoop.io.compress.GzipCodec;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * Runs DSSM item-embedding inference with Paddle Inference on Spark; the
+ * resulting vectors are written out for import into Milvus.
+ *
+ * @author dyp
+ */
+@Slf4j
+public class I2IMilvusDataImport {
+
+    private static HDFSService hdfsService = new HDFSService();
+
+    public static void main(String[] args) throws IOException {
+        CMDService cmd = new CMDService();
+        Map<String, String> argMap = cmd.parse(args);
+        // Input samples path; a distinct "vecInputPath" key is assumed here, since
+        // vecOutputPath is deleted and rewritten below and cannot also be the input
+        String file = argMap.get("vecInputPath");
+        String modelOssObjectName = argMap.get("modelOssObjectName");
+        String modelOssBucketName = argMap.get("modelOssBucketName");
+        String modelHdfsSavePath = argMap.get("modelHdfsSavePath");
+        String vecOutputPath = argMap.get("vecOutputPath");
+
+        int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
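+
+        // Model distribution: download the archive from OSS on the driver, stage it
+        // on HDFS, and let every partition pull and unpack its own local copy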
+        OSSService ossService = new OSSService();
+        String gzPath = "/root/recommend-model/model.tar.gz";
+        ossService.download(modelOssBucketName, gzPath, modelOssObjectName);
+
+        hdfsService.copyFromLocalFile(gzPath, modelHdfsSavePath);
+        // Set up Spark; the model itself is loaded inside each partition below
+        SparkSession spark = SparkSession.builder()
+                .appName("I2IDSSMInfer")
+                .getOrCreate();
+
+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+        JavaRDD<String> rdd = jsc.textFile(file);
+
+        // Per-partition processing: load the model once, then score line by line
+        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
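+            // Each task loads the Paddle native library and unpacks its own copy of
+            // the model; note that concurrent tasks on one executor share "./model.tar.gz"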
+            System.loadLibrary("paddle_inference");
+            hdfsService.copyToLocalFile(modelHdfsSavePath, "./model.tar.gz");
+            CompressUtil.decompressGzFile("./model.tar.gz", ".");
+
+            String modelFile = "dssm.pdmodel";
+            String paramFile = "dssm.pdiparams";
+
+            Config config = new Config();
+            config.setCppModel(modelFile, paramFile);
+            config.enableMemoryOptim(true);
+            config.enableMKLDNN();
+            config.switchIrDebug(false);
+
+            Predictor predictor = Predictor.createPaddlePredictor(config);
+
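+            // Wrap the partition in a lazy iterator so rows are scored one at a time
+            // instead of materializing the whole partition in memory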
+            return new Iterator<String>() {
+                private final Iterator<String> iterator = lines;
+
+                @Override
+                public boolean hasNext() {
+                    return iterator.hasNext();
+                }
+
+                @Override
+                public String next() {
+                    // 1. Parse the input line
+                    String line = iterator.next();
+                    String[] sampleValues = line.split("\t", -1); // -1 keeps trailing empty fields
+
+                    // Require at least two fields: vid and left_features_str
+                    if (sampleValues.length >= 2) {
+                        String vid = sampleValues[0];
+                        String leftFeaturesStr = sampleValues[1];
+
+                        // Split left_features_str and convert it to a float array
+                        String[] leftFeaturesArray = leftFeaturesStr.split(",");
+                        float[] leftFeatures = new float[leftFeaturesArray.length];
+                        for (int i = 0; i < leftFeaturesArray.length; i++) {
+                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
+                        }
+                        String inNames = predictor.getInputNameById(0);
+                        Tensor inHandle = predictor.getInputHandle(inNames);
+                        // 2. Set the input: a batch of one 157-dim feature vector
+                        inHandle.reshape(2, new int[]{1, 157});
+                        inHandle.copyFromCpu(leftFeatures);
+
+                        // 3. Run inference
+                        predictor.run();
+
+                        // 4. Read the output tensor
+                        String outNames = predictor.getOutputNameById(0);
+                        Tensor outHandle = predictor.getOutputHandle(outNames);
+                        float[] outData = new float[outHandle.getSize()];
+                        outHandle.copyToCpu(outData);
+
+                        String result = vid + "\t[" + StringUtils.join(outData, ',') + "]";
+
+                        outHandle.destroyNativeTensor();
+                        inHandle.destroyNativeTensor();
+
+                        return result;
+                    }
+                    // Malformed lines yield an empty string; filtered out before saving
+                    return "";
+                }
+            };
+        });
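+
+        // Each emitted row is "vid\t[v1,v2,...]", one embedding per video id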
+        // Write the processed data to the output path with Gzip compression
+        try {
+            hdfsService.deleteIfExist(vecOutputPath);
+        } catch (Exception e) {
+            log.error("deleteIfExist error outputPath {}", vecOutputPath, e);
+        }
+        processedRdd.filter(StringUtils::isNotBlank) // drop rows from malformed lines
+                .coalesce(repartition)
+                .saveAsTextFile(vecOutputPath, GzipCodec.class);
+    }
+
+}