소스 검색

dssm train

丁云鹏 10 달 전
부모
커밋
21dbe986a5

+ 6 - 0
recommend-model-produce/pom.xml

@@ -178,6 +178,12 @@
             <artifactId>fastjson</artifactId>
             <version>${fastjson.version}</version>
         </dependency>
+        <!-- Milvus SDK -->
+        <dependency>
+            <groupId>io.milvus</groupId>
+            <artifactId>milvus-sdk-java</artifactId>
+            <version>2.4.9</version>
+        </dependency>
         <dependency>
             <groupId>org.projectlombok</groupId>
             <artifactId>lombok</artifactId>

+ 87 - 92
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/i2i/I2IMilvusDataImport.java

@@ -1,22 +1,26 @@
 package com.tzld.piaoquan.recommend.model.produce.i2i;
 
-import com.baidu.paddle.inference.Config;
-import com.baidu.paddle.inference.Predictor;
-import com.baidu.paddle.inference.Tensor;
+import com.alibaba.fastjson.JSONObject;
+import io.milvus.client.*;
+import io.milvus.grpc.MutationResult;
+import io.milvus.param.*;
+import io.milvus.param.dml.InsertParam;
+import io.milvus.param.dml.UpsertParam;
+
 import com.tzld.piaoquan.recommend.model.produce.service.CMDService;
 import com.tzld.piaoquan.recommend.model.produce.service.HDFSService;
 import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
-import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.lang.math.NumberUtils;
 import org.apache.commons.lang3.StringUtils;
-import org.apache.hadoop.io.compress.GzipCodec;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.SparkSession;
 
 import java.io.IOException;
-import java.util.Iterator;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 /**
@@ -27,107 +31,98 @@ public class I2IMilvusDataImport {
 
     private static HDFSService hdfsService = new HDFSService();
 
+
     public static void main(String[] args) throws IOException {
         CMDService cmd = new CMDService();
         Map<String, String> argMap = cmd.parse(args);
-        String file = argMap.get("vecOutputPath");
-        String modelOssObjectName = argMap.get("modelOssObjectName");
-        String modelOssBucketName = argMap.get("modelOssBucketName");
-        String modelHdfsSavePath = argMap.get("modelHdfsSavePath");
-        String vecOutputPath = argMap.get("vecOutputPath");
-
-        int repartition = NumberUtils.toInt(argMap.get("repartition"), 64);
-
-        OSSService ossService = new OSSService();
-        String gzPath = "/root/recommend-model/model.tar.gz";
-        ossService.download(modelOssBucketName, gzPath, modelOssObjectName);
+        String file = argMap.get("dataPath");
+        String milvusUrl = argMap.get("milvusUrl");
+        String milvusToken = argMap.get("milvusToken");
+        String milvusCollection = argMap.get("milvusCollection");
+        int batchSize = NumberUtils.toInt(argMap.get("batchSize"), 5000);
 
-        hdfsService.copyFromLocalFile(gzPath, modelHdfsSavePath);
-        // 加载模型
         SparkSession spark = SparkSession.builder()
-                .appName("I2IDSSMInfer")
+                .appName("I2IMilvusDataImport")
                 .getOrCreate();
 
         JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
         JavaRDD<String> rdd = jsc.textFile(file);
 
         // 定义处理数据的函数
-        JavaRDD<String> processedRdd = rdd.mapPartitions(lines -> {
-            System.loadLibrary("paddle_inference");
-            hdfsService.copyToLocalFile(modelHdfsSavePath, "./model.tar.gz");
-            CompressUtil.decompressGzFile("./model.tar.gz", ".");
-
-            String modelFile = "dssm.pdmodel";
-            String paramFile = "dssm.pdiparams";
-
-            Config config = new Config();
-            config.setCppModel(modelFile, paramFile);
-            config.enableMemoryOptim(true);
-            config.enableMKLDNN();
-            config.switchIrDebug(false);
-
-            Predictor predictor = Predictor.createPaddlePredictor(config);
-
-
-            return new Iterator<String>() {
-                private final Iterator<String> iterator = lines;
-
-                @Override
-                public boolean hasNext() {
-                    return iterator.hasNext();
+        rdd.foreachPartition(lines -> {
+            ConnectParam connectParam = ConnectParam.newBuilder()
+                    .withUri(milvusUrl)
+                    .withToken(milvusToken)
+                    .build();
+            RetryParam retryParam = RetryParam.newBuilder()
+                    .withMaxRetryTimes(3)
+                    .build();
+            MilvusClient milvusClient = new MilvusServiceClient(connectParam).withRetry(retryParam);
+            List<String> batch = new ArrayList<>();
+            while (lines.hasNext()) {
+                String line = lines.next();
+                batch.add(line);
+                // 如果批量数据达到指定大小,则插入到Milvus
+                if (batch.size() >= batchSize) {
+                    upsertToMilvus(milvusClient, batch, milvusCollection);
+                    batch.clear();
                 }
+            }
 
-                @Override
-                public String next() {
-                    // 1 处理数据
-                    String line = lines.next();
-                    String[] sampleValues = line.split("\t", -1); // -1参数保持尾部空字符串
-
-                    // 检查是否有至少两个元素(vid和left_features_str)
-                    if (sampleValues.length >= 2) {
-                        String vid = sampleValues[0];
-                        String leftFeaturesStr = sampleValues[1];
-
-                        // 分割left_features_str并转换为float数组
-                        String[] leftFeaturesArray = leftFeaturesStr.split(",");
-                        float[] leftFeatures = new float[leftFeaturesArray.length];
-                        for (int i = 0; i < leftFeaturesArray.length; i++) {
-                            leftFeatures[i] = Float.parseFloat(leftFeaturesArray[i]);
-                        }
-                        String inNames = predictor.getInputNameById(0);
-                        Tensor inHandle = predictor.getInputHandle(inNames);
-                        // 2 设置输入
-                        inHandle.reshape(2, new int[]{1, 157});
-                        inHandle.copyFromCpu(leftFeatures);
-
-                        // 3 预测
-                        predictor.run();
-
-                        // 4 获取输入Tensor
-                        String outNames = predictor.getOutputNameById(0);
-                        Tensor outHandle = predictor.getOutputHandle(outNames);
-                        float[] outData = new float[outHandle.getSize()];
-                        outHandle.copyToCpu(outData);
-
-
-                        String result = vid + "\t[" + StringUtils.join(outData, ',') + "]";
-
-                        outHandle.destroyNativeTensor();
-                        inHandle.destroyNativeTensor();
-
-                        return result;
-                    }
-                    return "";
-                }
-            };
+            // 插入剩余的部分
+            if (!batch.isEmpty()) {
+                upsertToMilvus(milvusClient, batch, milvusCollection);
+            }
         });
-        // 将处理后的数据写入新的文件,使用Gzip压缩
-        try {
-            hdfsService.deleteIfExist(vecOutputPath);
-        } catch (Exception e) {
-            log.error("deleteIfExist error outputPath {}", vecOutputPath, e);
+    }
+
+    // 将数据批量 Upsert 到 Milvus
+    private static void upsertToMilvus(MilvusClient milvusClient, List<String> batch, String milvusCollection) {
+        List<UpsertParam.Field> fields = new ArrayList<>();
+        List<Long> ids = new ArrayList<>();
+        List<List<Float>> vectors = new ArrayList<>();
+
+        // 处理每条记录,假设每条记录包含 id 和 vector 字段,其他字段为动态列
+        Map<String, List<Object>> dynamicFields = new HashMap<>();
+
+        for (String record : batch) {
+            // 获取 ID 和向量
+            String[] values = StringUtils.split(record, "\t");
+            ids.add(Long.valueOf(values[0]));
+            vectors.add(JSONObject.parseArray(values[1], Float.class));
+
+//            // 遍历剩余字段(动态列)
+//            for (Map.Entry<String, Object> entry : record.entrySet()) {
+//                String key = entry.getKey();
+//                if (!key.equals("id") && !key.equals("vec")) {
+//                    // 如果动态字段不存在,则初始化
+//                    dynamicFields.computeIfAbsent(key, k -> new ArrayList<>()).add(entry.getValue());
+//                }
+//            }
+        }
+//
+//        // 将动态字段添加到 Milvus 插入字段中
+//        for (Map.Entry<String, List<Object>> dynamicEntry : dynamicFields.entrySet()) {
+//            fields.add(new UpsertParam.Field(dynamicEntry.getKey(), dynamicEntry.getValue()));
+//        }
+
+        // 将 id 和 vector 添加到字段中
+        fields.add(new UpsertParam.Field("vid", ids));
+        fields.add(new UpsertParam.Field("vec", vectors));
+
+        // 创建插入参数,设置 upsert 为 true
+        UpsertParam upsertParam = UpsertParam.newBuilder()
+                .withCollectionName(milvusCollection)
+                .withFields(fields)
+                .build();
+
+        // 执行 upsert
+        R<MutationResult> response = milvusClient.upsert(upsertParam);
+        if (response.getStatus().intValue() != 0) {
+            System.err.println("批量 Upsert 失败: " + response.getException().getMessage());
+        } else {
+            System.out.println("成功批量 Upsert 数据!");
         }
-        processedRdd.coalesce(repartition).saveAsTextFile(vecOutputPath, GzipCodec.class);
     }
 
 }