丁云鹏 4 maanden geleden
bovenliggende
commit
71266830e0

+ 1 - 1
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/Demo.java

@@ -29,7 +29,7 @@ public class Demo {
         final MilvusClientV2 milvusClientV2 = new MilvusClientV2(ConnectConfig.builder()
                 .uri("https://in01-bf9dcd371016170.ali-cn-hangzhou.vectordb.zilliz.com.cn:19530")
                 .token("423a29de63a907e6662b9493c4f95caf799f64f8701cc70db930bb6da7f05914e6ed2374342dc438a8b9d37da0bf164c8ee531bd")
-                .secure(false)
+                .secure(true)
                 .connectTimeoutMs(5000L)
                 .build());
         // Check if the collection exists

+ 143 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/Demo2.java

@@ -0,0 +1,143 @@
+package com.tzld.piaoquan.recommend.model.produce;
+
+import com.alibaba.fastjson.JSONObject;
+import com.tzld.piaoquan.recommend.model.produce.service.CMDService;
+import com.tzld.piaoquan.recommend.model.produce.service.HDFSService;
+import io.milvus.client.MilvusClient;
+import io.milvus.client.MilvusServiceClient;
+import io.milvus.grpc.MutationResult;
+import io.milvus.param.ConnectParam;
+import io.milvus.param.R;
+import io.milvus.param.RetryParam;
+import io.milvus.param.dml.UpsertParam;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @author dyp
+ */
+@Slf4j
+public class Demo2 {
+
+    private static HDFSService hdfsService = new HDFSService();
+
+
+    public static void main(String[] args) throws IOException {
+        CMDService cmd = new CMDService();
+        Map<String, String> argMap = cmd.parse(args);
+        String file = argMap.get("dataPath");
+        String milvusUrl = argMap.get("milvusUrl");
+        String milvusToken = argMap.get("milvusToken");
+        String milvusCollection = argMap.get("milvusCollection");
+        int batchSize = NumberUtils.toInt(argMap.get("batchSize"), 5000);
+
+        ConnectParam connectParamv1 = ConnectParam.newBuilder()
+                .withUri(milvusUrl)
+                .withToken(milvusToken)
+                .build();
+        RetryParam retryParamv1 = RetryParam.newBuilder()
+                .withMaxRetryTimes(3)
+                .build();
+        System.out.println(connectParamv1);
+        MilvusClient milvusClientv1 = new MilvusServiceClient(connectParamv1).withRetry(retryParamv1);
+
+        SparkSession spark = SparkSession.builder()
+                .appName("I2IMilvusDataImport")
+                .getOrCreate();
+
+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+        JavaRDD<String> rdd = jsc.textFile(file);
+
+        // 定义处理数据的函数
+        String finalMilvusUrl = milvusUrl;
+        String finalMilvusToken = milvusToken;
+        rdd.foreachPartition(lines -> {
+            ConnectParam connectParam = ConnectParam.newBuilder()
+                    .withUri(finalMilvusUrl)
+                    .withToken(finalMilvusToken)
+                    //.withAuthorization("emr", "Qingqu@2024")
+                    //.withConnectTimeout(60L, TimeUnit.SECONDS)
+                    .build();
+            RetryParam retryParam = RetryParam.newBuilder()
+                    .withMaxRetryTimes(3)
+                    .build();
+            System.out.println(connectParam);
+            MilvusClient milvusClient = new MilvusServiceClient(connectParam).withRetry(retryParam);
+            List<String> batch = new ArrayList<>();
+            while (lines.hasNext()) {
+                String line = lines.next();
+                batch.add(line);
+                // 如果批量数据达到指定大小,则插入到Milvus
+                if (batch.size() >= batchSize) {
+                    upsertToMilvus(milvusClient, batch, milvusCollection);
+                    batch.clear();
+                }
+            }
+
+            // 插入剩余的部分
+            if (!batch.isEmpty()) {
+                upsertToMilvus(milvusClient, batch, milvusCollection);
+            }
+        });
+    }
+
+    // 将数据批量 Upsert 到 Milvus
+    private static void upsertToMilvus(MilvusClient milvusClient, List<String> batch, String milvusCollection) {
+        List<UpsertParam.Field> fields = new ArrayList<>();
+        List<Long> ids = new ArrayList<>();
+        List<List<Float>> vectors = new ArrayList<>();
+
+        // 处理每条记录,假设每条记录包含 id 和 vector 字段,其他字段为动态列
+        Map<String, List<Object>> dynamicFields = new HashMap<>();
+
+        for (String record : batch) {
+            // 获取 ID 和向量
+            String[] values = StringUtils.split(record, "\t");
+            ids.add(Long.valueOf(values[0]));
+            vectors.add(JSONObject.parseArray(values[1], Float.class));
+
+//            // 遍历剩余字段(动态列)
+//            for (Map.Entry<String, Object> entry : record.entrySet()) {
+//                String key = entry.getKey();
+//                if (!key.equals("id") && !key.equals("vec")) {
+//                    // 如果动态字段不存在,则初始化
+//                    dynamicFields.computeIfAbsent(key, k -> new ArrayList<>()).add(entry.getValue());
+//                }
+//            }
+        }
+//
+//        // 将动态字段添加到 Milvus 插入字段中
+//        for (Map.Entry<String, List<Object>> dynamicEntry : dynamicFields.entrySet()) {
+//            fields.add(new UpsertParam.Field(dynamicEntry.getKey(), dynamicEntry.getValue()));
+//        }
+
+        // 将 id 和 vector 添加到字段中
+        fields.add(new UpsertParam.Field("vid", ids));
+        fields.add(new UpsertParam.Field("vec", vectors));
+
+        // 创建插入参数,设置 upsert 为 true
+        UpsertParam upsertParam = UpsertParam.newBuilder()
+                .withCollectionName(milvusCollection)
+                .withFields(fields)
+                .build();
+
+        // 执行 upsert
+        R<MutationResult> response = milvusClient.upsert(upsertParam);
+        if (response.getStatus().intValue() != 0) {
+            System.err.println("批量 Upsert 失败: " + response.getException().getMessage());
+        } else {
+            System.out.println("成功批量 Upsert 数据!");
+        }
+    }
+
+}