丁云鹏 8 months ago
parent
commit
e514663d53

+ 2 - 3
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java

@@ -71,9 +71,6 @@ public class XGBoostTrain {
             JavaRDD<Row> rowRDD = rdd.map(s -> {
                 String[] line = StringUtils.split("\t");
                 int label = NumberUtils.toInt(line[0]);
-                int[] indices = new int[line.length - 1];
-                double[] values = new double[line.length - 1];
-
                 // 选特征
                 Map<String, Double> map = new HashMap<>();
                 for (int i = 1; i < line.length; i++) {
@@ -85,6 +82,8 @@ public class XGBoostTrain {
                     values[i] = map.getOrDefault(features.get(i), 0.0);
                 }
 
+                int[] indices = new int[features.length];
+                double[] values = new double[features.length];
                 SparseVector vector = new SparseVector(indices.length, indices, values);
                 return RowFactory.create(label, vector);
             });