@@ -1,22 +1,26 @@
 package com.tzld.piaoquan.recommend.model.produce.i2i;

-import com.baidu.paddle.inference.Config;
-import com.baidu.paddle.inference.Predictor;
-import com.baidu.paddle.inference.Tensor;
+import com.alibaba.fastjson.JSONObject;
+import io.milvus.client.*;
+import io.milvus.grpc.MutationResult;
+import io.milvus.param.*;
+import io.milvus.param.dml.InsertParam;
+import io.milvus.param.dml.UpsertParam;
+
 import com.tzld.piaoquan.recommend.model.produce.service.CMDService;
 import com.tzld.piaoquan.recommend.model.produce.service.HDFSService;
 import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
-import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.lang.math.NumberUtils;
 import org.apache.commons.lang3.StringUtils;
-import org.apache.hadoop.io.compress.GzipCodec;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.SparkSession;

 import java.io.IOException;
-import java.util.Iterator;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Map;

 /**
@@ -27,107 +31,104 @@ public class I2IMilvusDataImport {

     private static HDFSService hdfsService = new HDFSService();

+
     public static void main(String[] args) throws IOException {
         CMDService cmd = new CMDService();
         Map<String, String> argMap = cmd.parse(args);
-        String file = argMap.get("vecOutputPath");
-        String modelOssObjectName = argMap.get("modelOssObjectName");
-        String modelOssBucketName = argMap.get("modelOssBucketName");
-        String modelHdfsSavePath = argMap.get("modelHdfsSavePath");
-        String vecOutputPath = argMap.get("vecOutputPath");
-
-        int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
-
-        OSSService ossService = new OSSService();
-        String gzPath = "/root/recommend-model/model.tar.gz";
-        ossService.download(modelOssBucketName, gzPath, modelOssObjectName);
+        String file = argMap.get("dataPath");
+        String milvusUrl = argMap.get("milvusUrl");
+        String milvusToken = argMap.get("milvusToken");
+        String milvusCollection = argMap.get("milvusCollection");
+        int batchSize = NumberUtils.toInt(argMap.get("batchSize"), 5000);
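+        // Expected arguments: dataPath (input text path), milvusUrl, milvusToken, milvusCollection, batchSize (optional, default 5000)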

-        hdfsService.copyFromLocalFile(gzPath, modelHdfsSavePath);
-        // Load the model
         SparkSession spark = SparkSession.builder()
-                .appName("I2IDSSMInfer")
+                .appName("I2IMilvusDataImport")
                 .getOrCreate();

         JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
         JavaRDD<String> rdd = jsc.textFile(file);

         // Process the data partition by partition
-        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
-            System.loadLibrary("paddle_inference");
-            hdfsService.copyToLocalFile(modelHdfsSavePath, "./model.tar.gz");
-            CompressUtil.decompressGzFile("./model.tar.gz", ".");
-
-            String modelFile = "dssm.pdmodel";
-            String paramFile = "dssm.pdiparams";
-
-            Config config = new Config();
-            config.setCppModel(modelFile, paramFile);
-            config.enableMemoryOptim(true);
-            config.enableMKLDNN();
-            config.switchIrDebug(false);
-
-            Predictor predictor = Predictor.createPaddlePredictor(config);
-
-
-            return new Iterator<String>() {
-                private final Iterator<String> iterator = lines;
-
-                @Override
-                public boolean hasNext() {
-                    return iterator.hasNext();
+        rdd.foreachPartition(lines -> {
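+            // Build the Milvus client inside each partition: it holds a gRPC connection and cannot be serialized on the driver and shipped to executors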
+            ConnectParam connectParam = ConnectParam.newBuilder()
+                    .withUri(milvusUrl)
+                    .withToken(milvusToken)
+                    .build();
+            RetryParam retryParam = RetryParam.newBuilder()
+                    .withMaxRetryTimes(3)
+                    .build();
+            MilvusClient milvusClient = new MilvusServiceClient(connectParam).withRetry(retryParam);
+            List<String> batch = new ArrayList<>();
+            while (lines.hasNext()) {
+                String line = lines.next();
+                batch.add(line);
+                // Once the batch reaches the configured size, upsert it into Milvus
+                if (batch.size() >= batchSize) {
+                    upsertToMilvus(milvusClient, batch, milvusCollection);
+                    batch.clear();
                 }
+            }

-                @Override
-                public String next() {
-                    // 1. Parse the input line
-                    String line = lines.next();
-                    String[] sampleValues = line.split("\t", -1); // -1 keeps trailing empty strings
-
-                    // Make sure there are at least two fields (vid and left_features_str)
-                    if (sampleValues.length >= 2) {
-                        String vid = sampleValues[0];
-                        String leftFeaturesStr = sampleValues[1];
-
-                        // Split left_features_str and convert it to a float array
-                        String[] leftFeaturesArray = leftFeaturesStr.split(",");
-                        float[] leftFeatures = new float[leftFeaturesArray.length];
-                        for (int i = 0; i < leftFeaturesArray.length; i++) {
-                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
-                        }
-                        String inNames = predictor.getInputNameById(0);
-                        Tensor inHandle = predictor.getInputHandle(inNames);
-                        // 2. Set the input tensor
-                        inHandle.reshape(2, new int[]{1, 157});
-                        inHandle.copyFromCpu(leftFeatures);
-
-                        // 3. Run inference
-                        predictor.run();
-
-                        // 4. Fetch the output tensor
-                        String outNames = predictor.getOutputNameById(0);
-                        Tensor outHandle = predictor.getOutputHandle(outNames);
-                        float[] outData = new float[outHandle.getSize()];
-                        outHandle.copyToCpu(outData);
-
-
-                        String result = vid + "\t[" + StringUtils.join(outData, ',') + "]";
-
-                        outHandle.destroyNativeTensor();
-                        inHandle.destroyNativeTensor();
-
-                        return result;
-                    }
-                    return "";
-                }
-            };
+            // Upsert whatever is left over
+            if (!batch.isEmpty()) {
+                upsertToMilvus(milvusClient, batch, milvusCollection);
+            }
+            milvusClient.close(); // release this partition's connection to Milvus
         });
-        // Write the processed data to a new file, Gzip-compressed
-        try {
-            hdfsService.deleteIfExist(vecOutputPath);
-        } catch (Exception e) {
-            log.error("deleteIfExist error outputPath {}", vecOutputPath, e);
+    }
+
+    // Upsert one batch of records into Milvus
+    private static void upsertToMilvus(MilvusClient milvusClient, List<String> batch, String milvusCollection) {
+        List<UpsertParam.Field> fields = new ArrayList<>();
+        List<Long> ids = new ArrayList<>();
+        List<List<Float>> vectors = new ArrayList<>();
+
+        // Each record is expected to carry an id and a vector; any additional fields would be dynamic columns
+        Map<String, List<Object>> dynamicFields = new HashMap<>();
+
+        for (String record : batch) {
+            // Extract the id and the vector
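+            // Assumed line format (the output of the former DSSM inference job): "<vid>\t[0.12,0.34,...]",
+            // i.e. a numeric id, a tab, then a JSON array of floats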
+            String[] values = StringUtils.split(record, "\t");
+            ids.add(Long.valueOf(values[0]));
+            vectors.add(JSONObject.parseArray(values[1], Float.class));
+
+//            // Iterate over the remaining fields (dynamic columns)
+//            for (Map.Entry<String, Object> entry : record.entrySet()) {
+//                String key = entry.getKey();
+//                if (!key.equals("id") && !key.equals("vec")) {
+//                    // Initialize the list for this dynamic field if it does not exist yet
+//                    dynamicFields.computeIfAbsent(key, k -> new ArrayList<>()).add(entry.getValue());
+//                }
+//            }
+        }
+//
+//        // Add the dynamic fields to the Milvus fields
+//        for (Map.Entry<String, List<Object>> dynamicEntry : dynamicFields.entrySet()) {
+//            fields.add(new UpsertParam.Field(dynamicEntry.getKey(), dynamicEntry.getValue()));
+//        }
+
+        // Add the id and vector columns to the fields
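+        // Assumes the target collection defines an Int64 primary-key field "vid" and a FloatVector field "vec" whose dimension matches these embeddings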
+        fields.add(new UpsertParam.Field("vid", ids));
+        fields.add(new UpsertParam.Field("vec", vectors));
+
+        // Build the upsert parameters
+        UpsertParam upsertParam = UpsertParam.newBuilder()
+                .withCollectionName(milvusCollection)
+                .withFields(fields)
+                .build();
+
+        // Execute the upsert
+        R<MutationResult> response = milvusClient.upsert(upsertParam);
+        if (response.getStatus().intValue() != 0) {
+            log.error("batch upsert failed: {}", response.getException().getMessage());
+        } else {
+            log.info("batch upsert succeeded, size={}", batch.size());
         }
-        processedRdd.coalesce(repartition).saveAsTextFile(vecOutputPath, GzipCodec.class);
     }

 }