|
@@ -1,17 +1,13 @@
|
|
package com.tzld.piaoquan.recommend.model.produce.xgboost;
|
|
package com.tzld.piaoquan.recommend.model.produce.xgboost;
|
|
|
|
|
|
-import com.google.common.collect.Lists;
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
|
|
import org.apache.commons.lang.math.NumberUtils;
|
|
import org.apache.commons.lang.math.NumberUtils;
|
|
-import org.apache.commons.lang3.RandomUtils;
|
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
import org.apache.commons.lang3.StringUtils;
|
|
import org.apache.spark.api.java.JavaRDD;
|
|
import org.apache.spark.api.java.JavaRDD;
|
|
import org.apache.spark.api.java.JavaSparkContext;
|
|
import org.apache.spark.api.java.JavaSparkContext;
|
|
import org.apache.spark.ml.feature.VectorAssembler;
|
|
import org.apache.spark.ml.feature.VectorAssembler;
|
|
-import org.apache.spark.ml.linalg.SparseVector;
|
|
|
|
-import org.apache.spark.ml.linalg.VectorUDT;
|
|
|
|
import org.apache.spark.sql.Dataset;
|
|
import org.apache.spark.sql.Dataset;
|
|
import org.apache.spark.sql.Row;
|
|
import org.apache.spark.sql.Row;
|
|
import org.apache.spark.sql.RowFactory;
|
|
import org.apache.spark.sql.RowFactory;
|
|
@@ -20,8 +16,10 @@ import org.apache.spark.sql.types.DataTypes;
|
|
import org.apache.spark.sql.types.StructField;
|
|
import org.apache.spark.sql.types.StructField;
|
|
import org.apache.spark.sql.types.StructType;
|
|
import org.apache.spark.sql.types.StructType;
|
|
|
|
|
|
-import java.util.*;
|
|
|
|
-import java.util.stream.Collectors;
|
|
|
|
|
|
+import java.util.ArrayList;
|
|
|
|
+import java.util.HashMap;
|
|
|
|
+import java.util.List;
|
|
|
|
+import java.util.Map;
|
|
|
|
|
|
/**
|
|
/**
|
|
* @author dyp
|
|
* @author dyp
|
|
@@ -93,10 +91,8 @@ public class XGBoostTrain {
|
|
|
|
|
|
Object[] v = new Object[features.length + 1];
|
|
Object[] v = new Object[features.length + 1];
|
|
v[0] = label;
|
|
v[0] = label;
|
|
- v[0] = RandomUtils.nextInt(0, 2);
|
|
|
|
- double[] values = new double[features.length];
|
|
|
|
|
|
+ //v[0] = RandomUtils.nextInt(0, 2);
|
|
for (int i = 0; i < features.length; i++) {
|
|
for (int i = 0; i < features.length; i++) {
|
|
- values[i] = map.getOrDefault(features[i], 0.0d);
|
|
|
|
v[i + 1] = map.getOrDefault(features[i], 0.0d);
|
|
v[i + 1] = map.getOrDefault(features[i], 0.0d);
|
|
}
|
|
}
|
|
|
|
|