|
@@ -7,6 +7,9 @@ import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
|
|
|
import org.apache.commons.lang.math.NumberUtils;
|
|
|
import org.apache.spark.api.java.JavaRDD;
|
|
|
import org.apache.spark.api.java.JavaSparkContext;
|
|
|
+import org.apache.spark.ml.feature.VectorAssembler;
|
|
|
+import org.apache.spark.mllib.linalg.SparseVector;
|
|
|
+import org.apache.spark.mllib.linalg.VectorUDT;
|
|
|
import org.apache.spark.sql.Dataset;
|
|
|
import org.apache.spark.sql.Row;
|
|
|
import org.apache.spark.sql.RowFactory;
|
|
@@ -39,21 +42,28 @@ public class XGBoostTrain {
|
|
|
JavaRDD<Row> rowRDD = rdd.map(s -> {
|
|
|
String[] line = StringUtils.split("\t");
|
|
|
int label = NumberUtils.toInt(line[0]);
|
|
|
+ int[] indices = new int[line.length - 1];
|
|
|
double[] values = new double[line.length - 1];
|
|
|
for (int i = 1; i < line.length; i++) {
|
|
|
String[] fv = StringUtils.split(":");
|
|
|
+ indices[i - 1] = i - 1;
|
|
|
values[i - 1] = NumberUtils.toDouble(fv[1], 0.0);
|
|
|
}
|
|
|
- return RowFactory.create(label, values);
|
|
|
+ SparseVector vector = new SparseVector(indices.length, indices, values);
|
|
|
+ return RowFactory.create(label, vector);
|
|
|
});
|
|
|
+
|
|
|
log.info("rowRDD count {}", rowRDD.count());
|
|
|
// 将 JavaRDD<Row> 转换为 Dataset<Row>
|
|
|
List<StructField> fields = new ArrayList<>();
|
|
|
fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
|
|
|
- fields.add(DataTypes.createStructField("features", new ArrayType(DataTypes.DoubleType, true), true));
|
|
|
+ fields.add(DataTypes.createStructField("features", new VectorUDT(), true));
|
|
|
StructType schema = DataTypes.createStructType(fields);
|
|
|
Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
|
|
|
|
|
|
+
|
|
|
+
|
|
|
// NOTE(review): VectorAssembler is imported but never used here — the features
// column is already assembled into a vector; remove the stale import/comment.
// Split the dataset into a 70% training set and a 30% test set.
// NOTE(review): randomSplit without a seed is non-deterministic across runs;
// pass a seed (e.g. randomSplit(new double[]{0.7, 0.3}, 42L)) for reproducibility.
Dataset<Row>[] splits = dataset.randomSplit(new double[]{0.7, 0.3});
Dataset<Row> trainData = splits[0];
|