|
@@ -60,7 +60,6 @@ public class XGBoostTrain {
|
|
|
Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
|
|
|
|
|
|
|
|
|
-
|
|
|
// 划分训练集和测试集
|
|
|
Dataset<Row>[] splits = dataset.randomSplit(new double[]{0.7, 0.3});
|
|
|
Dataset<Row> trainData = splits[0];
|
|
@@ -76,7 +75,7 @@ public class XGBoostTrain {
|
|
|
.setLabelCol("label")
|
|
|
.setMaxDepth(5)
|
|
|
.setObjective("binary:logistic")
|
|
|
- .setNthread(4)
|
|
|
+ .setNthread(1)
|
|
|
.setNumRound(10)
|
|
|
.setNumWorkers(2);
|
|
|
|