@@ -11,6 +11,13 @@ import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.ArrayList;
+import java.util.List;
 
 /**
  * @author dyp
@@ -41,10 +48,11 @@ public class XGBoostTrain {
         });
         log.info("rowRDD count {}", rowRDD.count());
         // Convert the JavaRDD<Row> into a Dataset<Row>
-        Dataset<Row> dataset = spark.createDataFrame(rowRDD, Row.class);
-
-        // Add column names if needed
-        dataset = dataset.toDF("label", "features");
+        List<StructField> fields = new ArrayList<>();
+        fields.add(DataTypes.createStructField("label", DataTypes.StringType, true));
+        fields.add(DataTypes.createStructField("features", new ArrayType(DataTypes.DoubleType, true), true));
+        StructType schema = DataTypes.createStructType(fields);
+        Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
 
         // Split into a training set and a test set
         Dataset<Row>[] splits = dataset.randomSplit(new double[]{0.7, 0.3});
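For reference, a minimal standalone sketch of the schema-based conversion this patch switches to. The class name, local master, and sample rows below are hypothetical; only the schema construction mirrors the patch. The replaced call, spark.createDataFrame(rowRDD, Row.class), goes through JavaBean reflection, and since Row is not a JavaBean it typically yields no usable columns; passing an explicit StructType instead binds each value in a RowFactory.create(...) row positionally to a named, typed column.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Hypothetical class name, used only for this illustration.
public class SchemaDataFrameSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .master("local[*]")            // assumption: local run for the sketch
                .appName("schema-sketch")
                .getOrCreate();
        JavaSparkContext jsc = new JavaSparkContext(spark.sparkContext());

        // Each Row must line up with the schema positionally:
        // column 0 -> String label, column 1 -> array of doubles.
        JavaRDD<Row> rowRDD = jsc.parallelize(Arrays.asList(
                RowFactory.create("1", new double[]{0.1, 0.2, 0.3}),
                RowFactory.create("0", new double[]{0.4, 0.5, 0.6})));

        // Same schema as in the patch; DataTypes.createArrayType(DoubleType, true)
        // is equivalent to new ArrayType(DataTypes.DoubleType, true).
        List<StructField> fields = new ArrayList<>();
        fields.add(DataTypes.createStructField("label", DataTypes.StringType, true));
        fields.add(DataTypes.createStructField("features",
                DataTypes.createArrayType(DataTypes.DoubleType, true), true));
        StructType schema = DataTypes.createStructType(fields);

        // Explicit schema instead of bean-class reflection.
        Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
        dataset.printSchema();
        dataset.show(false);

        spark.stop();
    }
}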