Merge branch 'main'

# Conflicts:
#	recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java
zhangbo, 8 months ago
Commit 72f6f710ce

+ 0 - 5
recommend-model-produce/pom.xml

@@ -167,11 +167,6 @@
             <artifactId>guava</artifactId>
             <version>14.0.1</version>
         </dependency>
-<!--        <dependency>-->
-<!--            <groupId>io.netty</groupId>-->
-<!--            <artifactId>netty-all</artifactId>-->
-<!--            <version>4.1.17.Final</version>-->
-<!--        </dependency>-->
         <dependency>
             <groupId>org.scala-lang</groupId>
             <artifactId>scala-library</artifactId>

+ 54 - 23
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/OSSService.java

@@ -2,13 +2,13 @@ package com.tzld.piaoquan.recommend.model.produce.service;
 
 import com.aliyun.oss.OSS;
 import com.aliyun.oss.OSSClientBuilder;
-import com.aliyun.oss.model.CopyObjectRequest;
-import com.aliyun.oss.model.CopyObjectResult;
-import com.aliyun.oss.model.ObjectMetadata;
+import com.aliyun.oss.model.GetObjectRequest;
+import com.aliyun.oss.model.PutObjectRequest;
+import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
 import lombok.extern.slf4j.Slf4j;
 
+import java.io.File;
 import java.io.Serializable;
-import java.util.List;
 
 /**
  * @author dyp
@@ -17,24 +17,55 @@ import java.util.List;
 public class OSSService implements Serializable {
     private String accessId = "LTAI5tHMkNaRhpiDB1yWMZPn";
     private String accessKey = "XLi5YUJusVwbbQOaGeGsaRJ1Qyzbui";
-    private String endpoint = "https://oss-cn-hangzhou-internal.aliyuncs.com";
-
-    public void upload(String bucketName, String srcPath, String orcPath) {
-//        OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);
-//        try {
-//            if (objectName.startsWith("http")) {
-//                continue;
-//            }
-//            CopyObjectRequest request = new CopyObjectRequest(bucketName, objectName, bucketName, objectName);
-//            ObjectMetadata objectMetadata = new ObjectMetadata();
-//            objectMetadata.setHeader("x-oss-storage-class", "DeepColdArchive");
-//            request.setNewObjectMetadata(objectMetadata);
-//            CopyObjectResult result = ossClient.copyObject(request);
-//        } catch (Exception e) {
-//            log.error("transToDeepColdArchive error {} {}", objectName, e.getMessage(), e);
-//        }
-//        if (ossClient != null) {
-//            ossClient.shutdown();
-//        }
+    //private String endpoint = "https://oss-cn-hangzhou-internal.aliyuncs.com";
+    private String endpoint = "https://oss-cn-hangzhou.aliyuncs.com";
+
+    public void upload(String bucketName, String localFile, String objectName) {
+        OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);
+        try {
+            PutObjectRequest request = new PutObjectRequest(bucketName, objectName, new File(localFile));
+            ossClient.putObject(request);
+        } catch (Exception e) {
+            log.error("upload error bucketName {}, localFile {}, objectName {}", bucketName, localFile, objectName, e);
+        }
+        if (ossClient != null) {
+            ossClient.shutdown();
+        }
+    }
+
+    public void download(String bucketName, String localFile, String objectName) {
+        OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);
+        try {
+            GetObjectRequest request = new GetObjectRequest(bucketName, objectName);
+            ossClient.getObject(request, new File(localFile));
+        } catch (Exception e) {
+            log.error("download error bucketName {}, localFile {}, objectName {}", bucketName, localFile, objectName, e);
+        }
+        if (ossClient != null) {
+            ossClient.shutdown();
+        }
+    }
+
+    public static void main(String[] args) {
+
+        String bucketName = "art-test-video";
+        String objectName = "test/model.tar.gz";
+        OSSService ossService = new OSSService();
+
+
+//        String inputPath = "/Users/dingyunpeng/Desktop/model";
+//        String outputPath = "/Users/dingyunpeng/Desktop/model.tar.gz";
+//        CompressUtil.compressDirectoryToGzip(inputPath, outputPath);
+//
+//        String ossPath = "test/model.tar.gz";
+//        ossService.upload(bucketName, outputPath, ossPath);
+
+
+        String destPath = "/Users/dingyunpeng/Desktop/model2.tar.gz";
+        ossService.download(bucketName, destPath, objectName);
+        String destDir = "/Users/dingyunpeng/Desktop/model2";
+        CompressUtil.decompressGzFile(destPath, destDir);
+
     }
 }
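
The upload/download helpers above shut the OSS client down after the try/catch block. An equivalent, slightly more defensive structure is try/finally; a minimal sketch of upload() in that style, assuming the same aliyun-sdk-oss classes already imported in this file:

    public void upload(String bucketName, String localFile, String objectName) {
        OSS ossClient = new OSSClientBuilder().build(endpoint, accessId, accessKey);
        try {
            // Put the local file under the given object key.
            ossClient.putObject(new PutObjectRequest(bucketName, objectName, new File(localFile)));
        } catch (Exception e) {
            log.error("upload error bucketName {}, localFile {}, objectName {}", bucketName, localFile, objectName, e);
        } finally {
            // Release the client's underlying connections even if an unexpected Error is thrown.
            ossClient.shutdown();
        }
    }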

+ 184 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/XGBoostService.java

@@ -0,0 +1,184 @@
+package com.tzld.piaoquan.recommend.model.produce.service;
+
+import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;
+import lombok.extern.slf4j.Slf4j;
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
+import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.RandomUtils;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+import org.apache.spark.ml.feature.VectorAssembler;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @author dyp
+ */
+@Slf4j
+public class XGBoostService {
+
+
+    public void train(String[] args) {
+        try {
+            Dataset<Row> assembledData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz");
+            log.info("show training samples");
+            assembledData.show();
+            // Create the XGBoostClassifier
+            XGBoostClassifier xgbClassifier = new XGBoostClassifier()
+                    .setEta(0.01f)
+                    .setSubsample(0.8)
+                    .setColsampleBytree(0.8)
+                    .setScalePosWeight(1)
+                    .setSeed(2024)
+                    .setMissing(0.0f)
+                    .setFeaturesCol("features")
+                    .setLabelCol("label")
+                    .setMaxDepth(5)
+                    .setObjective("binary:logistic")
+                    .setNthread(1)
+                    .setNumRound(100)
+                    .setNumWorkers(1);
+
+
+            // Train the model
+            XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
+
+            // Save the model, then compress and upload it to OSS
+            String path = "/root/recommend-model/modeltrain";
+            model.write().overwrite().save("file://" + path);
+            String outputPath = "/root/recommend-model/model.tar.gz";
+            CompressUtil.compressDirectoryToGzip(path, outputPath);
+            String bucketName = "art-test-video";
+            String ossPath = "test/model.tar.gz";
+            OSSService ossService = new OSSService();
+            ossService.upload(bucketName, outputPath, ossPath);
+
+        } catch (Throwable e) {
+            log.error("", e);
+        }
+    }
+
+    public void predict(String[] args) {
+        try {
+
+            Dataset<Row> assembledData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz");
+            log.info("show test samples");
+            assembledData.show();
+
+            // Download the model archive from OSS and unpack it
+            String bucketName = "art-test-video";
+            String objectName = "test/model.tar.gz";
+            OSSService ossService = new OSSService();
+
+            String destPath = "/root/recommend-model/model2.tar.gz";
+            ossService.download(bucketName, destPath, objectName);
+            String destDir = "/root/recommend-model/modelpredict";
+            CompressUtil.decompressGzFile(destPath, destDir);
+
+            // Load the model and show prediction results
+            XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
+            Dataset<Row> predictions = model2.transform(assembledData);
+            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show(500);
+
+            // Compute AUC
+            Dataset<Row> selected = predictions.select("label", "rawPrediction");
+            BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
+                    .setLabelCol("label")
+                    .setRawPredictionCol("rawPrediction")
+                    .setMetricName("areaUnderROC");
+            double auc = evaluator.evaluate(selected);
+            log.info("AUC: {}", auc);
+
+        } catch (Throwable e) {
+            log.error("", e);
+        }
+    }
+
+    private static Dataset<Row> dataset(String path) {
+        String[] features = {"cpa",
+                "b2_1h_ctr",
+                "b2_1h_ctcvr",
+                "b2_1h_cvr",
+                "b2_1h_conver",
+                "b2_1h_click",
+                "b2_1h_conver*log(view)",
+                "b2_1h_conver*ctcvr",
+                "b2_2h_ctr",
+                "b2_2h_ctcvr",
+                "b2_2h_cvr",
+                "b2_2h_conver",
+                "b2_2h_click",
+                "b2_2h_conver*log(view)",
+                "b2_2h_conver*ctcvr",
+                "b2_3h_ctr",
+                "b2_3h_ctcvr",
+                "b2_3h_cvr",
+                "b2_3h_conver",
+                "b2_3h_click",
+                "b2_3h_conver*log(view)",
+                "b2_3h_conver*ctcvr",
+                "b2_6h_ctr",
+                "b2_6h_ctcvr"
+        };
+
+
+        SparkSession spark = SparkSession.builder()
+                .appName("XGBoostTrain")
+                .master("local")
+                .getOrCreate();
+
+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+        String file = path;
+        JavaRDD<String> rdd = jsc.textFile(file);
+
+        JavaRDD<Row> rowRDD = rdd.map(s -> {
+            String[] line = StringUtils.split(s, '\t');
+            int label = NumberUtils.toInt(line[0]);
+            // Parse feature:value pairs
+            Map<String, Double> map = new HashMap<>();
+            for (int i = 1; i < line.length; i++) {
+                String[] fv = StringUtils.split(line[i], ':');
+                map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+            }
+
+            Object[] v = new Object[features.length + 1];
+            v[0] = label;
+            for (int i = 0; i < features.length; i++) {
+                v[i + 1] = map.getOrDefault(features[i], 0.0d);
+            }
+
+            return RowFactory.create(v);
+        });
+
+        log.info("rowRDD count {}", rowRDD.count());
+        // Convert JavaRDD<Row> to Dataset<Row>
+        List<StructField> fields = new ArrayList<>();
+        fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
+        for (String f : features) {
+            fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
+        }
+        StructType schema = DataTypes.createStructType(fields);
+        Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
+
+        VectorAssembler assembler = new VectorAssembler()
+                .setInputCols(features)
+                .setOutputCol("features");
+
+        Dataset<Row> assembledData = assembler.transform(dataset);
+        assembledData.show();
+        return assembledData;
+    }
+}
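
dataset() above expects each input line to be tab-separated: the label first, then feature:value pairs; any feature from the features array that is missing on a line falls back to 0.0 via map.getOrDefault. A standalone sketch of that parsing, using a hypothetical sample line in the layout the code implies:

    import org.apache.commons.lang.math.NumberUtils;
    import org.apache.commons.lang3.StringUtils;

    import java.util.HashMap;
    import java.util.Map;

    public class SampleLineParseSketch {
        public static void main(String[] args) {
            // Hypothetical sample: label, then tab-separated feature:value pairs.
            String sample = "1\tcpa:0.35\tb2_1h_ctr:0.12\tb2_1h_click:3";
            String[] line = StringUtils.split(sample, '\t');
            int label = NumberUtils.toInt(line[0]);
            Map<String, Double> map = new HashMap<>();
            for (int i = 1; i < line.length; i++) {
                String[] fv = StringUtils.split(line[i], ':');
                map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
            }
            // Features named in the features array but absent here end up as 0.0 in the Row.
            System.out.println(label + " -> " + map);
        }
    }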

+ 89 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/util/CompressUtil.java

@@ -0,0 +1,89 @@
+package com.tzld.piaoquan.recommend.model.produce.util;
+
+import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
+import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
+import org.apache.commons.compress.archivers.tar.TarArchiveOutputStream;
+import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream;
+import org.apache.commons.compress.compressors.gzip.GzipCompressorOutputStream;
+
+import java.io.*;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+
+/**
+ * @author dyp
+ */
+@Slf4j
+public class CompressUtil {
+    public static void compressDirectoryToGzip(String sourceDirPath, String outputFilePath) {
+        // Create the output streams for the .gz archive
+        try (OutputStream out = new FileOutputStream(outputFilePath);
+             GzipCompressorOutputStream gzipOut = new GzipCompressorOutputStream(out);
+             TarArchiveOutputStream taos = new TarArchiveOutputStream(gzipOut)) {
+
+            taos.setLongFileMode(TarArchiveOutputStream.LONGFILE_GNU);
+
+            // Walk the source directory
+            Files.walk(Paths.get(sourceDirPath))
+                    .filter(Files::isRegularFile)
+                    .forEach(filePath -> {
+                        try {
+                            // Create a TarArchiveEntry for each file
+                            TarArchiveEntry entry = new TarArchiveEntry(filePath.toFile(), filePath.toString().substring(sourceDirPath.length() + 1));
+                            taos.putArchiveEntry(entry);
+
+                            // Read the file content and write it to the TarArchiveOutputStream
+                            try (InputStream is = Files.newInputStream(filePath)) {
+                                byte[] buffer = new byte[1024];
+                                int len;
+                                while ((len = is.read(buffer)) > 0) {
+                                    taos.write(buffer, 0, len);
+                                }
+                            }
+                            // Close the entry
+                            taos.closeArchiveEntry();
+                        } catch (IOException e) {
+                            log.error("", e);
+                        }
+                    });
+        } catch (Exception e) {
+            log.error("", e);
+        }
+    }
+
+    public static void decompressGzFile(String gzipFilePath, String destDirPath) {
+        try (InputStream gzipIn = new FileInputStream(gzipFilePath);
+             GzipCompressorInputStream gzIn = new GzipCompressorInputStream(gzipIn);
+             TarArchiveInputStream tais = new TarArchiveInputStream(gzIn)) {
+
+            TarArchiveEntry entry;
+            Files.createDirectories(Paths.get(destDirPath));
+            while ((entry = tais.getNextTarEntry()) != null) {
+                if (entry.isDirectory()) {
+                    // If it is a directory, create it
+                    Files.createDirectories(Paths.get(destDirPath, entry.getName()));
+                } else {
+                    // If it is a file, create it and write its content
+                    File outputFile = new File(destDirPath, entry.getName());
+                    if (!outputFile.exists()) {
+                        File parent = outputFile.getParentFile();
+                        if (!parent.exists()) {
+                            parent.mkdirs();
+                        }
+                        outputFile.createNewFile();
+                    }
+                    try (OutputStream out = new FileOutputStream(outputFile)) {
+                        byte[] buffer = new byte[1024];
+                        int len;
+                        while ((len = tais.read(buffer)) > 0) {
+                            out.write(buffer, 0, len);
+                        }
+                    }
+                }
+            }
+        } catch (Exception e) {
+            log.error("", e);
+        }
+    }
+}
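
A minimal round-trip usage sketch of the two helpers above; the paths are hypothetical, and the archive produced is the same gzip-compressed tar that OSSService's main() uploads and downloads:

    import com.tzld.piaoquan.recommend.model.produce.util.CompressUtil;

    public class CompressRoundTripSketch {
        public static void main(String[] args) {
            // Hypothetical local paths; point modelDir at an existing directory before running.
            String modelDir = "/tmp/model";
            String archive = "/tmp/model.tar.gz";
            String restoredDir = "/tmp/model-restored";

            // Pack the directory into a .tar.gz, then unpack it into a fresh directory.
            CompressUtil.compressDirectoryToGzip(modelDir, archive);
            CompressUtil.decompressGzFile(archive, restoredDir);
        }
    }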

+ 16 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostPredict.java

@@ -0,0 +1,16 @@
+package com.tzld.piaoquan.recommend.model.produce.xgboost;
+
+import com.tzld.piaoquan.recommend.model.produce.service.XGBoostService;
+import lombok.extern.slf4j.Slf4j;
+
+/**
+ * @author dyp
+ */
+@Slf4j
+public class XGBoostPredict {
+
+    public static void main(String[] args) {
+        XGBoostService xgb = new XGBoostService();
+        xgb.predict(args);
+    }
+}

+ 3 - 120
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java

@@ -1,26 +1,7 @@
 package com.tzld.piaoquan.recommend.model.produce.xgboost;
 
+import com.tzld.piaoquan.recommend.model.produce.service.XGBoostService;
 import lombok.extern.slf4j.Slf4j;
-import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
-import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
-import org.apache.commons.lang.math.NumberUtils;
-import org.apache.commons.lang3.RandomUtils;
-import org.apache.commons.lang3.StringUtils;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.feature.VectorAssembler;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
 
 /**
  * @author dyp
@@ -29,105 +10,7 @@ import java.util.Map;
 public class XGBoostTrain {
 
     public static void main(String[] args) {
-        try {
-
-            String[] features = {"cpa",
-                    "b2_12h_ctr",
-                    "b2_12h_ctcvr",
-                    "b2_12h_cvr",
-                    "b2_12h_conver",
-                    "b2_12h_click",
-                    "b2_12h_conver*log(view)",
-                    "b2_12h_conver*ctcvr",
-                    "b2_7d_ctr",
-                    "b2_7d_ctcvr",
-                    "b2_7d_cvr",
-                    "b2_7d_conver",
-                    "b2_7d_click",
-                    "b2_7d_conver*log(view)",
-                    "b2_7d_conver*ctcvr"
-            };
-
-
-            SparkSession spark = SparkSession.builder()
-                    .appName("XGBoostTrain")
-                    .master("local")
-                    .getOrCreate();
-
-            JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
-            String file = "/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz";
-            //file = "/Users/dingyunpeng/Desktop/part-00099.gz";
-            JavaRDD<String> rdd = jsc.textFile(file);
-
-            JavaRDD<Row> rowRDD = rdd.map(s -> {
-                String[] line = StringUtils.split(s, '\t');
-                int label = NumberUtils.toInt(line[0]);
-                // 选特征
-                Map<String, Double> map = new HashMap<>();
-                for (int i = 1; i < line.length; i++) {
-                    String[] fv = StringUtils.split(line[i], ':');
-                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
-                }
-
-                Object[] v = new Object[features.length + 1];
-                v[0] = label;
-//                v[0] = RandomUtils.nextInt(0, 2);
-                for (int i = 0; i < features.length; i++) {
-                    v[i + 1] = map.getOrDefault(features[i], 0.0d);
-                }
-
-                return RowFactory.create(v);
-            });
-
-            log.info("rowRDD count {}", rowRDD.count());
-            // 将 JavaRDD<Row> 转换为 Dataset<Row>
-            List<StructField> fields = new ArrayList<>();
-            fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
-            for (String f : features) {
-                fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
-            }
-            StructType schema = DataTypes.createStructType(fields);
-            Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
-
-            VectorAssembler assembler = new VectorAssembler()
-                    .setInputCols(features)
-                    .setOutputCol("features");
-
-            Dataset<Row> assembledData = assembler.transform(dataset);
-            assembledData.show();
-            // 划分训练集和测试集
-            Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3});
-            Dataset<Row> trainData = splits[0];
-            trainData.show(500);
-            Dataset<Row> testData = splits[1];
-            testData.show(500);
-
-            // 参数
-
-
-            // 创建 XGBoostClassifier 对象
-            XGBoostClassifier xgbClassifier = new XGBoostClassifier()
-                    .setEta(0.1f)
-                    .setMissing(0.0f)
-                    .setFeaturesCol("features")
-                    .setLabelCol("label")
-                    .setMaxDepth(5)
-                    .setObjective("binary:logistic")
-                    .setNthread(1)
-                    .setNumRound(5)
-                    .setNumWorkers(1);
-
-
-            // 训练模型
-            XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
-
-            // 显示预测结果
-            Dataset<Row> predictions = model.transform(assembledData);
-            predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
-
-
-        } catch (Throwable e) {
-            log.error("", e);
-        }
+        XGBoostService xgb = new XGBoostService();
+        xgb.train(args);
     }
 }

+ 86 - 99
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrainLocalTest.java

@@ -4,10 +4,10 @@ import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
 import org.apache.commons.lang.math.NumberUtils;
-import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 import org.apache.spark.ml.feature.VectorAssembler;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
@@ -31,111 +31,22 @@ public class XGBoostTrainLocalTest {
     public static void main(String[] args) {
         try {
 
-            String[] features = {"cpa",
-                    "b2_12h_ctr",
-                    "b2_12h_ctcvr",
-                    "b2_12h_cvr",
-                    "b2_12h_conver",
-                    "b2_12h_click",
-                    "b2_12h_conver*log(view)",
-                    "b2_12h_conver*ctcvr",
-                    "b2_7d_ctr",
-                    "b2_7d_ctcvr",
-                    "b2_7d_cvr",
-                    "b2_7d_conver",
-                    "b2_7d_click",
-                    "b2_7d_conver*log(view)",
-                    "b2_7d_conver*ctcvr"
-            };
-
-
-            SparkSession spark = SparkSession.builder()
-                    .appName("XGBoostTrain")
-                    .master("local")
-                    .getOrCreate();
-
-            JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
-            String file = "/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz";
-            file = "/Users/dingyunpeng/Desktop/part-00099.gz";
-            JavaRDD<String> rdd = jsc.textFile(file);
-
-            // 将 RDD[LabeledPoint] 转换为 JavaRDD<Row>
-//            JavaRDD<Row> rowRDD = rdd.map(s -> {
-//                String[] line = StringUtils.split(s, '\t');
-//                int label = NumberUtils.toInt(line[0]);
-//                // 选特征
-//                Map<String, Double> map = new HashMap<>();
-//                for (int i = 1; i < line.length; i++) {
-//                    String[] fv = StringUtils.split(line[i], ':');
-//                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
-//                }
-//
-//                int[] indices = new int[features.length];
-//                double[] values = new double[features.length];
-//                for (int i = 0; i < features.length; i++) {
-//                    indices[i] = i;
-//                    values[i] = map.getOrDefault(features[i], 0.0);
-//                }
-//                SparseVector vector = new SparseVector(indices.length, indices, values);
-//                return RowFactory.create(label, vector);
-//            });
-
-            JavaRDD<Row> rowRDD = rdd.map(s -> {
-                String[] line = StringUtils.split(s, '\t');
-                int label = NumberUtils.toInt(line[0]);
-                // 选特征
-                Map<String, Double> map = new HashMap<>();
-                for (int i = 1; i < line.length; i++) {
-                    String[] fv = StringUtils.split(line[i], ':');
-                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
-                }
-
-                Object[] v = new Object[features.length + 1];
-                v[0] = label;
-                v[0] = RandomUtils.nextInt(0, 2);
-                for (int i = 0; i < features.length; i++) {
-                    v[i + 1] = map.getOrDefault(features[i], 0.0d);
-                }
-
-                return RowFactory.create(v);
-            });
-
-            log.info("rowRDD count {}", rowRDD.count());
-            // 将 JavaRDD<Row> 转换为 Dataset<Row>
-            List<StructField> fields = new ArrayList<>();
-            fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
-            for (String f : features) {
-                fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
-            }
-            StructType schema = DataTypes.createStructType(fields);
-            Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
-
-            VectorAssembler assembler = new VectorAssembler()
-                    .setInputCols(features)
-                    .setOutputCol("features");
-
-            Dataset<Row> assembledData = assembler.transform(dataset);
-            assembledData.show();
-            // 划分训练集和测试集
-            Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3});
-            Dataset<Row> trainData = splits[0];
-            trainData.show(500);
-            Dataset<Row> testData = splits[1];
-            testData.show(500);
-
-            // 参数
-
+            Dataset<Row> assembledData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
 
             // Create the XGBoostClassifier
             XGBoostClassifier xgbClassifier = new XGBoostClassifier()
-                    .setEta(0.1f)
+                    .setEta(0.01f)
+                    .setSubsample(0.8)
+                    .setColsampleBytree(0.8)
+                    .setScalePosWeight(1)
+                    .setSeed(2024)
                     .setMissing(0.0f)
                     .setFeaturesCol("features")
                     .setLabelCol("label")
                     .setMaxDepth(5)
                     .setObjective("binary:logistic")
                     .setNthread(1)
-                    .setNumRound(5)
+                    .setNumRound(100)
                     .setNumWorkers(1);
 
 
@@ -143,12 +54,88 @@ public class XGBoostTrainLocalTest {
             XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
 
             // Show prediction results
-            Dataset<Row> predictions = model.transform(assembledData);
-            predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
+            Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
+            Dataset<Row> predictions = model.transform(predictData);
+            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show();
 
+            // Compute AUC
+            Dataset<Row> selected = predictions.select("label", "rawPrediction");
+            BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
+                    .setLabelCol("label")
+                    .setRawPredictionCol("rawPrediction")
+                    .setMetricName("areaUnderROC");
+            double auc = evaluator.evaluate(selected);
+
+            log.info("AUC: {}", auc);
 
         } catch (Throwable e) {
             log.error("", e);
         }
     }
+
+    private static Dataset<Row> dataset(String path) {
+        String[] features = {"cpa",
+                "b2_12h_ctr",
+                "b2_12h_ctcvr",
+                "b2_12h_cvr",
+                "b2_12h_conver",
+                "b2_12h_click",
+                "b2_12h_conver*log(view)",
+                "b2_12h_conver*ctcvr",
+                "b2_7d_ctr",
+                "b2_7d_ctcvr",
+                "b2_7d_cvr",
+                "b2_7d_conver",
+                "b2_7d_click",
+                "b2_7d_conver*log(view)",
+                "b2_7d_conver*ctcvr"
+        };
+
+
+        SparkSession spark = SparkSession.builder()
+                .appName("XGBoostTrain")
+                .master("local")
+                .getOrCreate();
+
+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+        String file = path;
+        JavaRDD<String> rdd = jsc.textFile(file);
+
+        JavaRDD<Row> rowRDD = rdd.map(s -> {
+            String[] line = StringUtils.split(s, '\t');
+            int label = NumberUtils.toInt(line[0]);
+            // Parse feature:value pairs
+            Map<String, Double> map = new HashMap<>();
+            for (int i = 1; i < line.length; i++) {
+                String[] fv = StringUtils.split(line[i], ':');
+                map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+            }
+
+            Object[] v = new Object[features.length + 1];
+            v[0] = label;
+            for (int i = 0; i < features.length; i++) {
+                v[i + 1] = map.getOrDefault(features[i], 0.0d);
+            }
+
+            return RowFactory.create(v);
+        });
+
+        log.info("rowRDD count {}", rowRDD.count());
+        // Convert JavaRDD<Row> to Dataset<Row>
+        List<StructField> fields = new ArrayList<>();
+        fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
+        for (String f : features) {
+            fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
+        }
+        StructType schema = DataTypes.createStructType(fields);
+        Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
+
+        VectorAssembler assembler = new VectorAssembler()
+                .setInputCols(features)
+                .setOutputCol("features");
+
+        Dataset<Row> assembledData = assembler.transform(dataset);
+        assembledData.show();
+        return assembledData;
+    }
 }
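
The local test above fits the model and then measures AUC on data built from the same part file, so the reported AUC is optimistic. A sketch of a held-out evaluation for main(), reusing the dataset() helper, the xgbClassifier and the log field defined in this class, plus the randomSplit call the earlier version of the class used (the fixed seed is an assumption for reproducibility):

    Dataset<Row> assembledData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
    // Split into 70% train / 30% test.
    Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3}, 2024L);
    Dataset<Row> trainData = splits[0];
    Dataset<Row> testData = splits[1];

    XGBoostClassificationModel model = xgbClassifier.fit(trainData);
    Dataset<Row> predictions = model.transform(testData);

    double auc = new BinaryClassificationEvaluator()
            .setLabelCol("label")
            .setRawPredictionCol("rawPrediction")
            .setMetricName("areaUnderROC")
            .evaluate(predictions.select("label", "rawPrediction"));
    log.info("held-out AUC: {}", auc);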