丁云鹏 8 месяцев назад
Родитель
Commit
c0e5db7dbf

+ 5 - 9
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java

@@ -1,17 +1,13 @@
 package com.tzld.piaoquan.recommend.model.produce.xgboost;
 
-import com.google.common.collect.Lists;
 import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
 import org.apache.commons.lang.math.NumberUtils;
-import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.feature.VectorAssembler;
-import org.apache.spark.ml.linalg.SparseVector;
-import org.apache.spark.ml.linalg.VectorUDT;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
@@ -20,8 +16,10 @@ import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
-import java.util.*;
-import java.util.stream.Collectors;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 
 /**
  * @author dyp
@@ -93,10 +91,8 @@ public class XGBoostTrain {
 
                 Object[] v = new Object[features.length + 1];
                 v[0] = label;
-                v[0] = RandomUtils.nextInt(0, 2);
-                double[] values = new double[features.length];
+                //v[0] = RandomUtils.nextInt(0, 2);
                 for (int i = 0; i < features.length; i++) {
-                    values[i] = map.getOrDefault(features[i], 0.0d);
                     v[i + 1] = map.getOrDefault(features[i], 0.0d);
                 }