丁云鹏 8 months ago
parent
commit
9eb6b20e12

+ 6 - 2
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java

@@ -108,14 +108,18 @@ public class XGBoostTrain {
 
             // Create the XGBoostClassifier object
             XGBoostClassifier xgbClassifier = new XGBoostClassifier()
-                    .setEta(0.1f)
+                    .setEta(0.01f)
+                    .setSubsample(0.8)
+                    .setColsampleBytree(0.8)
+                    .setScalePosWeight(1)
+                    .setSeed(2024)
                     .setMissing(0.0f)
                     .setFeaturesCol("features")
                     .setLabelCol("label")
                     .setMaxDepth(5)
                     .setObjective("binary:logistic")
                     .setNthread(1)
-                    .setNumRound(5)
+                    .setNumRound(100)
                     .setNumWorkers(1);
 
 

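The XGBoostTrain.java change above lowers the learning rate (eta 0.1 -> 0.01) and compensates with more boosting rounds (numRound 5 -> 100), adds row and column subsampling (subsample / colsample_bytree at 0.8), pins scale_pos_weight to 1, and fixes the random seed so runs are reproducible. A minimal sketch of that configuration pulled into a reusable helper is below; the buildClassifier method name is an assumption, while every setter is one already used in this commit.

    // Sketch: the classifier configuration from this commit as a reusable helper.
    // buildClassifier is a hypothetical name; the setters are the xgboost4j-spark
    // ones called in the diff above, in the same order.
    private static XGBoostClassifier buildClassifier() {
        return new XGBoostClassifier()
                .setEta(0.01f)                    // smaller step per round ...
                .setSubsample(0.8)                // 80% of rows sampled per tree
                .setColsampleBytree(0.8)          // 80% of features sampled per tree
                .setScalePosWeight(1)             // no extra weight on the positive class
                .setSeed(2024)                    // fixed seed for reproducibility
                .setMissing(0.0f)                 // value treated as missing
                .setFeaturesCol("features")
                .setLabelCol("label")
                .setMaxDepth(5)
                .setObjective("binary:logistic")
                .setNthread(1)
                .setNumRound(100)                 // ... offset by many more boosting rounds
                .setNumWorkers(1);
    }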
+ 85 - 89
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrainLocalTest.java

@@ -10,6 +10,7 @@ import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 import org.apache.spark.ml.feature.VectorAssembler;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
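The refactored loader (the dataset() helper further down) reads gzipped text where each line is a tab-separated record: the first token is the 0/1 label and every following token is featureName:value. A tiny sketch of that parsing, using a made-up sample line (the concrete values are illustrative only):

    // Sketch of the record format dataset() parses; the sample values are invented.
    String sample = "1\tcpa:3.5\tb2_12h_ctr:0.02\tb2_7d_ctcvr:0.004";
    String[] cols = StringUtils.split(sample, '\t');
    int label = NumberUtils.toInt(cols[0]);            // -> 1
    String[] fv = StringUtils.split(cols[1], ':');     // -> ["cpa", "3.5"]
    double value = NumberUtils.toDouble(fv[1], 0.0);   // -> 3.5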
@@ -33,116 +34,111 @@ public class XGBoostTrainLocalTest {
     public static void main(String[] args) {
         try {
 
-            String[] features = {"cpa",
-                    "b2_12h_ctr",
-                    "b2_12h_ctcvr",
-                    "b2_12h_cvr",
-                    "b2_12h_conver",
-                    "b2_12h_click",
-                    "b2_12h_conver*log(view)",
-                    "b2_12h_conver*ctcvr",
-                    "b2_7d_ctr",
-                    "b2_7d_ctcvr",
-                    "b2_7d_cvr",
-                    "b2_7d_conver",
-                    "b2_7d_click",
-                    "b2_7d_conver*log(view)",
-                    "b2_7d_conver*ctcvr"
-            };
-
-
-            SparkSession spark = SparkSession.builder()
-                    .appName("XGBoostTrain")
-                    .master("local")
-                    .getOrCreate();
-
-            JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
-            String file = "/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz";
-            file = "/Users/dingyunpeng/Desktop/part-00099.gz";
-            JavaRDD<String> rdd = jsc.textFile(file);
-
-            JavaRDD<Row> rowRDD = rdd.map(s -> {
-                String[] line = StringUtils.split(s, '\t');
-                int label = NumberUtils.toInt(line[0]);
-                // select features
-                Map<String, Double> map = new HashMap<>();
-                for (int i = 1; i < line.length; i++) {
-                    String[] fv = StringUtils.split(line[i], ':');
-                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
-                }
-
-                Object[] v = new Object[features.length + 1];
-                v[0] = label;
-                v[0] = RandomUtils.nextInt(0, 2);
-                for (int i = 0; i < features.length; i++) {
-                    v[i + 1] = map.getOrDefault(features[i], 0.0d);
-                }
-
-                return RowFactory.create(v);
-            });
-
-            log.info("rowRDD count {}", rowRDD.count());
-            // convert the JavaRDD<Row> to a Dataset<Row>
-            List<StructField> fields = new ArrayList<>();
-            fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
-            for (String f : features) {
-                fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
-            }
-            StructType schema = DataTypes.createStructType(fields);
-            Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
-
-            VectorAssembler assembler = new VectorAssembler()
-                    .setInputCols(features)
-                    .setOutputCol("features");
-
-            Dataset<Row> assembledData = assembler.transform(dataset);
-            assembledData.show();
-            // split into training and test sets
-            Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3});
-            Dataset<Row> trainData = splits[0];
-            trainData.show(500);
-            Dataset<Row> testData = splits[1];
-            testData.show(500);
-
-            // parameters
-
+            Dataset<Row> assembledData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
 
             // Create the XGBoostClassifier object
             XGBoostClassifier xgbClassifier = new XGBoostClassifier()
-                    .setEta(0.1f)
+                    .setEta(0.01f)
+                    .setSubsample(0.8)
+                    .setColsampleBytree(0.8)
+                    .setScalePosWeight(1)
+                    .setSeed(2024)
                     .setMissing(0.0f)
                     .setFeaturesCol("features")
                     .setLabelCol("label")
                     .setMaxDepth(5)
                     .setObjective("binary:logistic")
                     .setNthread(1)
-                    .setNumRound(5)
+                    .setNumRound(100)
                     .setNumWorkers(1);
 
 
             // Train the model
             XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
 
-            // save the model
-            String path = "/Users/dingyunpeng/Desktop/model";
-            model.write().overwrite().save(path);
-
-            String outputPath = "/Users/dingyunpeng/Desktop/model.tar.gz";
-            CompressUtil.compressDirectoryToGzip(path, outputPath);
-            String bucketName = "art-test-video";
-            String ossPath = "test/model.tar.gz";
-            OSSService ossService = new OSSService();
-            ossService.upload(bucketName, outputPath, ossPath);
-
             // Show the prediction results
-            XGBoostClassificationModel model2 = XGBoostClassificationModel.load(path);
-            Dataset<Row> predictions = model2.transform(assembledData);
+            Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
+            Dataset<Row> predictions = model.transform(predictData);
             predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
 
-            spark.close();
+            // compute AUC
+            Dataset<Row> selected = predictions.select("label", "probability");
+            BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
+                    .setLabelCol("label")
+                    .setRawPredictionCol("probability")
+                    .setMetricName("areaUnderROC");
+            double auc = evaluator.evaluate(selected);
+
+            log.info("AUC: {}", auc);
 
         } catch (Throwable e) {
             log.error("", e);
         }
     }
+
+    private static Dataset<Row> dataset(String path) {
+        String[] features = {"cpa",
+                "b2_12h_ctr",
+                "b2_12h_ctcvr",
+                "b2_12h_cvr",
+                "b2_12h_conver",
+                "b2_12h_click",
+                "b2_12h_conver*log(view)",
+                "b2_12h_conver*ctcvr",
+                "b2_7d_ctr",
+                "b2_7d_ctcvr",
+                "b2_7d_cvr",
+                "b2_7d_conver",
+                "b2_7d_click",
+                "b2_7d_conver*log(view)",
+                "b2_7d_conver*ctcvr"
+        };
+
+
+        SparkSession spark = SparkSession.builder()
+                .appName("XGBoostTrain")
+                .master("local")
+                .getOrCreate();
+
+        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
+        String file = path;
+        JavaRDD<String> rdd = jsc.textFile(file);
+
+        JavaRDD<Row> rowRDD = rdd.map(s -> {
+            String[] line = StringUtils.split(s, '\t');
+            int label = NumberUtils.toInt(line[0]);
+            // select features
+            Map<String, Double> map = new HashMap<>();
+            for (int i = 1; i < line.length; i++) {
+                String[] fv = StringUtils.split(line[i], ':');
+                map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+            }
+
+            Object[] v = new Object[features.length + 1];
+            v[0] = label;
+            for (int i = 0; i < features.length; i++) {
+                v[i + 1] = map.getOrDefault(features[i], 0.0d);
+            }
+
+            return RowFactory.create(v);
+        });
+
+        log.info("rowRDD count {}", rowRDD.count());
+        // convert the JavaRDD<Row> to a Dataset<Row>
+        List<StructField> fields = new ArrayList<>();
+        fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
+        for (String f : features) {
+            fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
+        }
+        StructType schema = DataTypes.createStructType(fields);
+        Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
+
+        VectorAssembler assembler = new VectorAssembler()
+                .setInputCols(features)
+                .setOutputCol("features");
+
+        Dataset<Row> assembledData = assembler.transform(dataset);
+        assembledData.show();
+        return assembledData;
+    }
 }
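The evaluation added at the end of main() feeds the probability vector into Spark ML's BinaryClassificationEvaluator as the raw-prediction column and reports areaUnderROC. Note that this commit trains and scores on the same part-00099.gz file, so the logged AUC reflects fit on the training data rather than held-out performance. A minimal sketch of the same evaluator applied to a random split (the 0.7/0.3 split mirrors the one this commit removed; the seed and variable names are assumptions):

    // Sketch: evaluate on a held-out split instead of the training file.
    // assembledData and xgbClassifier are the objects built above; the split
    // ratio, seed and variable names are assumptions for illustration.
    Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3}, 2024L);
    Dataset<Row> trainData = splits[0];
    Dataset<Row> testData = splits[1];

    XGBoostClassificationModel heldOutModel = xgbClassifier.fit(trainData);
    Dataset<Row> heldOutPredictions = heldOutModel.transform(testData);

    BinaryClassificationEvaluator rocEvaluator = new BinaryClassificationEvaluator()
            .setLabelCol("label")
            .setRawPredictionCol("probability")
            .setMetricName("areaUnderROC");
    BinaryClassificationEvaluator prEvaluator = new BinaryClassificationEvaluator()
            .setLabelCol("label")
            .setRawPredictionCol("probability")
            .setMetricName("areaUnderPR");

    log.info("held-out AUC-ROC: {}, AUC-PR: {}",
            rocEvaluator.evaluate(heldOutPredictions), prEvaluator.evaluate(heldOutPredictions));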