|
@@ -1,22 +1,26 @@
|
|
|
package com.tzld.piaoquan.recommend.model.produce.i2i;
|
|
|
|
|
|
-import com.baidu.paddle.inference.Config;
|
|
|
-import com.baidu.paddle.inference.Predictor;
|
|
|
-import com.baidu.paddle.inference.Tensor;
|
|
|
+import com.alibaba.fastjson.JSONObject;
|
|
|
+import io.milvus.client.*;
|
|
|
+import io.milvus.grpc.MutationResult;
|
|
|
+import io.milvus.param.*;
|
|
|
+import io.milvus.param.dml.InsertParam;
|
|
|
+import io.milvus.param.dml.UpsertParam;
|
|
|
+
|
|
|
import com.tzld.piaoquan.recommend.model.produce.service.CMDService;
|
|
|
import com.tzld.piaoquan.recommend.model.produce.service.HDFSService;
|
|
|
import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
|
|
|
-import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.apache.commons.lang.math.NumberUtils;
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
|
-import org.apache.hadoop.io.compress.GzipCodec;
|
|
|
import org.apache.spark.api.java.JavaRDD;
|
|
|
import org.apache.spark.api.java.JavaSparkContext;
|
|
|
import org.apache.spark.sql.SparkSession;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
-import java.util.Iterator;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.List;
|
|
|
import java.util.Map;
|
|
|
|
|
|
|
|
@@ -27,107 +31,98 @@ public class I2IMilvusDataImport {
|
|
|
|
|
|
private static HDFSService hdfsService = new HDFSService();
|
|
|
|
|
|
+
|
|
|
public static void main(String[] args) throws IOException {
|
|
|
CMDService cmd = new CMDService();
|
|
|
Map<String, String> argMap = cmd.parse(args);
|
|
|
- String file = argMap.get("vecOutputPath");
|
|
|
- String modelOssObjectName = argMap.get("modelOssObjectName");
|
|
|
- String modelOssBucketName = argMap.get("modelOssBucketName");
|
|
|
- String modelHdfsSavePath = argMap.get("modelHdfsSavePath");
|
|
|
- String vecOutputPath = argMap.get("vecOutputPath");
|
|
|
-
|
|
|
- int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
|
|
|
-
|
|
|
- OSSService ossService = new OSSService();
|
|
|
- String gzPath = "/root/recommend-model/model.tar.gz";
|
|
|
- ossService.download(modelOssBucketName, gzPath, modelOssObjectName);
|
|
|
+ String file = argMap.get("dataPath");
|
|
|
+ String milvusUrl = argMap.get("milvusUrl");
|
|
|
+ String milvusToken = argMap.get("milvusToken");
|
|
|
+ String milvusCollection = argMap.get("milvusCollection");
|
|
|
+ int batchSize = NumberUtils.toInt(argMap.get("batchSize"), 5000);
|
|
|
|
|
|
- hdfsService.copyFromLocalFile(gzPath, modelHdfsSavePath);
|
|
|
-
|
|
|
SparkSession spark = SparkSession.builder()
|
|
|
- .appName("I2IDSSMInfer")
|
|
|
+ .appName("I2IMilvusDataImport")
|
|
|
.getOrCreate();
|
|
|
|
|
|
JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
|
|
|
JavaRDD<String> rdd = jsc.textFile(file);
|
|
|
|
|
|
|
|
|
- JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
|
|
|
- System.loadLibrary("paddle_inference");
|
|
|
- hdfsService.copyToLocalFile(modelHdfsSavePath, "./model.tar.gz");
|
|
|
- CompressUtil.decompressGzFile("./model.tar.gz", ".");
|
|
|
-
|
|
|
- String modelFile = "dssm.pdmodel";
|
|
|
- String paramFile = "dssm.pdiparams";
|
|
|
-
|
|
|
- Config config = new Config();
|
|
|
- config.setCppModel(modelFile, paramFile);
|
|
|
- config.enableMemoryOptim(true);
|
|
|
- config.enableMKLDNN();
|
|
|
- config.switchIrDebug(false);
|
|
|
-
|
|
|
- Predictor predictor = Predictor.createPaddlePredictor(config);
|
|
|
-
|
|
|
-
|
|
|
- return new Iterator<String>() {
|
|
|
- private final Iterator<String> iterator = lines;
|
|
|
-
|
|
|
- @Override
|
|
|
- public boolean hasNext() {
|
|
|
- return iterator.hasNext();
|
|
|
+ rdd.foreachPartition(lines -> {
|
|
|
+ ConnectParam connectParam = ConnectParam.newBuilder()
|
|
|
+ .withUri(milvusUrl)
|
|
|
+ .withToken(milvusToken)
|
|
|
+ .build();
|
|
|
+ RetryParam retryParam = RetryParam.newBuilder()
|
|
|
+ .withMaxRetryTimes(3)
|
|
|
+ .build();
|
|
|
+ MilvusClient milvusClient = new MilvusServiceClient(connectParam).withRetry(retryParam);
|
|
|
+ List<String> batch = new ArrayList<>();
|
|
|
+ while (lines.hasNext()) {
|
|
|
+ String line = lines.next();
|
|
|
+ batch.add(line);
|
|
|
+
|
|
|
+ if (batch.size() >= batchSize) {
|
|
|
+ upsertToMilvus(milvusClient, batch, milvusCollection);
|
|
|
+ batch.clear();
|
|
|
}
|
|
|
+ }
|
|
|
|
|
|
- @Override
|
|
|
- public String next() {
|
|
|
-
|
|
|
- String line = lines.next();
|
|
|
- String[] sampleValues = line.split("\t", -1);
|
|
|
-
|
|
|
-
|
|
|
- if (sampleValues.length >= 2) {
|
|
|
- String vid = sampleValues[0];
|
|
|
- String leftFeaturesStr = sampleValues[1];
|
|
|
-
|
|
|
-
|
|
|
- String[] leftFeaturesArray = leftFeaturesStr.split(",");
|
|
|
- float[] leftFeatures = new float[leftFeaturesArray.length];
|
|
|
- for (int i = 0; i < leftFeaturesArray.length; i++) {
|
|
|
- leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
|
|
|
- }
|
|
|
- String inNames = predictor.getInputNameById(0);
|
|
|
- Tensor inHandle = predictor.getInputHandle(inNames);
|
|
|
-
|
|
|
- inHandle.reshape(2, new int[]{1, 157});
|
|
|
- inHandle.copyFromCpu(leftFeatures);
|
|
|
-
|
|
|
-
|
|
|
- predictor.run();
|
|
|
-
|
|
|
-
|
|
|
- String outNames = predictor.getOutputNameById(0);
|
|
|
- Tensor outHandle = predictor.getOutputHandle(outNames);
|
|
|
- float[] outData = new float[outHandle.getSize()];
|
|
|
- outHandle.copyToCpu(outData);
|
|
|
-
|
|
|
-
|
|
|
- String result = vid + "\t[" + StringUtils.join(outData, ',') + "]";
|
|
|
-
|
|
|
- outHandle.destroyNativeTensor();
|
|
|
- inHandle.destroyNativeTensor();
|
|
|
-
|
|
|
- return result;
|
|
|
- }
|
|
|
- return "";
|
|
|
- }
|
|
|
- };
|
|
|
+
|
|
|
+ if (!batch.isEmpty()) {
|
|
|
+ upsertToMilvus(milvusClient, batch, milvusCollection);
|
|
|
+ }
|
|
|
});
|
|
|
-
|
|
|
- try {
|
|
|
- hdfsService.deleteIfExist(vecOutputPath);
|
|
|
- } catch (Exception e) {
|
|
|
- log.error("deleteIfExist error outputPath {}", vecOutputPath, e);
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ private static void upsertToMilvus(MilvusClient milvusClient, List<String> batch, String milvusCollection) {
|
|
|
+ List<UpsertParam.Field> fields = new ArrayList<>();
|
|
|
+ List<Long> ids = new ArrayList<>();
|
|
|
+ List<List<Float>> vectors = new ArrayList<>();
|
|
|
+
|
|
|
+
|
|
|
+ Map<String, List<Object>> dynamicFields = new HashMap<>();
|
|
|
+
|
|
|
+ for (String record : batch) {
|
|
|
+
|
|
|
+ String[] values = StringUtils.split(record, "\t");
|
|
|
+ ids.add(Long.valueOf(values[0]));
|
|
|
+ vectors.add(JSONObject.parseArray(values[1], Float.class));
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ fields.add(new UpsertParam.Field("vid", ids));
|
|
|
+ fields.add(new UpsertParam.Field("vec", vectors));
|
|
|
+
|
|
|
+
|
|
|
+ UpsertParam upsertParam = UpsertParam.newBuilder()
|
|
|
+ .withCollectionName(milvusCollection)
|
|
|
+ .withFields(fields)
|
|
|
+ .build();
|
|
|
+
|
|
|
+
|
|
|
+ R<MutationResult> response = milvusClient.upsert(upsertParam);
|
|
|
+ if (response.getStatus().intValue() != 0) {
|
|
|
+ System.err.println("批量 Upsert 失败: " + response.getException().getMessage());
|
|
|
+ } else {
|
|
|
+ System.out.println("成功批量 Upsert 数据!");
|
|
|
}
|
|
|
- processedRdd.coalesce(repartition).saveAsTextFile(vecOutputPath, GzipCodec.class);
|
|
|
}
|
|
|
|
|
|
}
|