|
@@ -29,7 +29,7 @@ public class XGBoostTrain {
|
|
|
try {
|
|
|
SparkSession spark = SparkSession.builder()
|
|
|
.appName("XGBoostTrain")
|
|
|
- .master("local")
|
|
|
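+ // NOTE(review): .master("local") is no longer hard-coded; presumably the master URL is now supplied externally (e.g. spark-submit --master or the spark.master config) — confirm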
|
|
|
.getOrCreate();
|
|
|
|
|
|
JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());
|
|
@@ -51,7 +51,7 @@ public class XGBoostTrain {
|
|
|
return RowFactory.create(label, vector);
|
|
|
});
|
|
|
|
|
|
- log.info("rowRDD count {}", rowRDD.count());
|
|
|
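+ // log.info("rowRDD count {}", rowRDD.count()); // NOTE(review): presumably disabled because count() is an RDD action that runs a full job just to log the size — confirm, then delete this line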
|
|
|
// 将 JavaRDD<Row> 转换为 Dataset<Row>
|
|
|
List<StructField> fields = new ArrayList<>();
|
|
|
fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
|
|
@@ -61,7 +61,6 @@ public class XGBoostTrain {
|
|
|
|
|
|
|
|
|
|
|
|
- // 使用 VectorAssembler 转换数据
|
|
|
// 划分训练集和测试集
|
|
|
Dataset<Row>[] splits = dataset.randomSplit(new double[]{0.7, 0.3});
|
|
|
Dataset<Row> trainData = splits[0];
|