|
@@ -49,7 +49,7 @@ public class XGBoostTrain {
|
|
|
log.info("rowRDD count {}", rowRDD.count());
|
|
|
// 将 JavaRDD<Row> 转换为 Dataset<Row>
|
|
|
List<StructField> fields = new ArrayList<>();
|
|
|
- fields.add(DataTypes.createStructField("label", DataTypes.StringType, true));
|
|
|
+ fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
|
|
|
fields.add(DataTypes.createStructField("features", new ArrayType(DataTypes.DoubleType, true), true));
|
|
|
StructType schema = DataTypes.createStructType(fields);
|
|
|
Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
|