|
@@ -71,9 +71,6 @@ public class XGBoostTrain {
|
|
|
JavaRDD<Row> rowRDD = rdd.map(s -> {
|
|
|
String[] line = StringUtils.split("\t");
|
|
|
int label = NumberUtils.toInt(line[0]);
|
|
|
- int[] indices = new int[line.length - 1];
|
|
|
- double[] values = new double[line.length - 1];
|
|
|
-
|
|
|
// 选特征
|
|
|
Map<String, Double> map = new HashMap<>();
|
|
|
for (int i = 1; i < line.length; i++) {
|
|
@@ -85,6 +82,8 @@ public class XGBoostTrain {
|
|
|
values[i] = map.getOrDefault(features.get(i), 0.0);
|
|
|
}
|
|
|
|
|
|
+ int[] indices = new int[features.length];
|
|
|
+ double[] values = new double[features.length];
|
|
|
SparseVector vector = new SparseVector(indices.length, indices, values);
|
|
|
return RowFactory.create(label, vector);
|
|
|
});
|