@@ -5,9 +5,11 @@ import lombok.extern.slf4j.Slf4j;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
 import org.apache.commons.lang.math.NumberUtils;
+import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.feature.VectorAssembler;
 import org.apache.spark.ml.linalg.SparseVector;
 import org.apache.spark.ml.linalg.VectorUDT;
 import org.apache.spark.sql.Dataset;
@@ -18,10 +20,8 @@ import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
+import java.util.stream.Collectors;
 
 /**
  * @author dyp
@@ -32,30 +32,22 @@ public class XGBoostTrain {
     public static void main(String[] args) {
         try {
 
-            List<String> features = Lists.newArrayList("cpa",
-                    "b2_1h_ctr",
-                    "b2_1h_ctcvr",
-                    "b2_1h_cvr",
-                    "b2_1h_conver",
-                    "b2_1h_click",
-                    "b2_1h_conver*log(view)",
-                    "b2_1h_conver*ctcvr",
-                    "b2_2h_ctr",
-                    "b2_2h_ctcvr",
-                    "b2_2h_cvr",
-                    "b2_2h_conver",
-                    "b2_2h_click",
-                    "b2_2h_conver*log(view)",
-                    "b2_2h_conver*ctcvr",
-                    "b2_3h_ctr",
-                    "b2_3h_ctcvr",
-                    "b2_3h_cvr",
-                    "b2_3h_conver",
-                    "b2_3h_click",
-                    "b2_3h_conver*log(view)",
-                    "b2_3h_conver*ctcvr",
-                    "b2_6h_ctr",
-                    "b2_6h_ctcvr");
+            String[] features = {"cpa",
+                    "b2_12h_ctr",
+                    "b2_12h_ctcvr",
+                    "b2_12h_cvr",
+                    "b2_12h_conver",
+                    "b2_12h_click",
+                    "b2_12h_conver*log(view)",
+                    "b2_12h_conver*ctcvr",
+                    "b2_7d_ctr",
+                    "b2_7d_ctcvr",
+                    "b2_7d_cvr",
+                    "b2_7d_conver",
+                    "b2_7d_click",
+                    "b2_7d_conver*log(view)",
+                    "b2_7d_conver*ctcvr"
+            };
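+            // each feature name doubles as a DataFrame column name and must match the
+            // "feature:value" keys parsed from the input lines below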
 
             SparkSession spark = SparkSession.builder()
@@ -69,6 +61,26 @@ public class XGBoostTrain {
             JavaRDD<String> rdd = jsc.textFile(file);
 
             // convert RDD[LabeledPoint] to JavaRDD<Row>
+//            JavaRDD<Row> rowRDD = rdd.map(s -> {
+//                String[] line = StringUtils.split(s, '\t');
+//                int label = NumberUtils.toInt(line[0]);
+//                // select features
+//                Map<String, Double> map = new HashMap<>();
+//                for (int i = 1; i < line.length; i++) {
+//                    String[] fv = StringUtils.split(line[i], ':');
+//                    map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
+//                }
+//
+//                int[] indices = new int[features.length];
+//                double[] values = new double[features.length];
+//                for (int i = 0; i < features.length; i++) {
+//                    indices[i] = i;
+//                    values[i] = map.getOrDefault(features[i], 0.0);
+//                }
+//                SparseVector vector = new SparseVector(indices.length, indices, values);
+//                return RowFactory.create(label, vector);
+//            });
+
             JavaRDD<Row> rowRDD = rdd.map(s -> {
                 String[] line = StringUtils.split(s, '\t');
                 int label = NumberUtils.toInt(line[0]);
@@ -79,29 +91,40 @@ public class XGBoostTrain {
                     map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
                 }
 
-                int[] indices = new int[features.size()];
-                double[] values = new double[features.size()];
-                for (int i = 0; i < features.size(); i++) {
-                    indices[i] = i;
-                    values[i] = map.getOrDefault(features.get(i), 0.0);
+                Object[] v = new Object[features.length + 1];
+                v[0] = label;
+                // NOTE: the next line overwrites the parsed label with a random 0/1,
+                // so the model trains on noise; keep it only as a pipeline smoke test
+                v[0] = RandomUtils.nextInt(0, 2);
+                double[] values = new double[features.length]; // unused leftover from the SparseVector version
+                for (int i = 0; i < features.length; i++) {
+                    values[i] = map.getOrDefault(features[i], 0.0d);
+                    v[i + 1] = map.getOrDefault(features[i], 0.0d);
                 }
-                SparseVector vector = new SparseVector(indices.length, indices, values);
-                return RowFactory.create(label, vector);
+
+                return RowFactory.create(v);
             });
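+            // illustrative input line assumed by the parser above (tab-separated; values are made up):
+            //   label \t feature:value \t feature:value ...
+            //   e.g. "1\tcpa:12.5\tb2_12h_ctr:0.031"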
 
             log.info("rowRDD count {}", rowRDD.count());
             // convert JavaRDD<Row> to Dataset<Row>
             List<StructField> fields = new ArrayList<>();
             fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
-            fields.add(DataTypes.createStructField("features", new VectorUDT(), true));
+            for (String f : features) {
+                fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
+            }
             StructType schema = DataTypes.createStructType(fields);
             Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
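+            // schema column order must match RowFactory.create(v): label first, then one double column per feature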
 
+            VectorAssembler assembler = new VectorAssembler()
+                    .setInputCols(features)
+                    .setOutputCol("features");
+
+            Dataset<Row> assembledData = assembler.transform(dataset);
+            assembledData.show();
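+            // VectorAssembler concatenates the per-feature double columns into the single
+            // vector column ("features") that Spark ML estimators such as XGBoostClassifier consume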
             // split into training and test sets
-            Dataset<Row>[] splits = dataset.randomSplit(new double[]{0.7, 0.3});
+            Dataset<Row>[] splits = assembledData.randomSplit(new double[]{0.7, 0.3});
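+            // tip: randomSplit(new double[]{0.7, 0.3}, 42L) makes the split reproducible (seed value is illustrative)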
             Dataset<Row> trainData = splits[0];
+            trainData.show();
             Dataset<Row> testData = splits[1];
+            testData.show();
 
             // parameters
 
@@ -124,7 +147,7 @@ public class XGBoostTrain {
 
             // show prediction results
             Dataset<Row> predictions = model.transform(testData);
-            predictions.select("label", "prediction").show(30000);
+            predictions.show(100);
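+            // Editor's sketch (not part of this change): evaluate AUC on the test split with
+            // Spark ML's BinaryClassificationEvaluator, assuming the standard "rawPrediction"
+            // column is produced by the classifier; needs
+            // import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
+            // BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
+            //         .setLabelCol("label")
+            //         .setRawPredictionCol("rawPrediction")
+            //         .setMetricName("areaUnderROC");
+            // log.info("test AUC {}", evaluator.evaluate(predictions));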
 
         } catch (Throwable e) {