|
@@ -1,6 +1,7 @@
|
|
|
package com.tzld.piaoquan.recommend.model.produce.xgboost;
|
|
|
|
|
|
import com.aliyun.odps.utils.StringUtils;
|
|
|
+import com.google.common.collect.Lists;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
|
|
@@ -18,15 +19,45 @@ import org.apache.spark.sql.types.StructField;
|
|
|
import org.apache.spark.sql.types.StructType;
|
|
|
|
|
|
import java.util.ArrayList;
|
|
|
+import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
|
|
|
/**
|
|
|
* @author dyp
|
|
|
*/
|
|
|
@Slf4j
|
|
|
public class XGBoostTrain {
|
|
|
+
|
|
|
public static void main(String[] args) {
|
|
|
try {
|
|
|
+
|
|
|
+ List<String> features = Lists.newArrayList("cpa",
|
|
|
+ "b2_1h_ctr",
|
|
|
+ "b2_1h_ctcvr",
|
|
|
+ "b2_1h_cvr",
|
|
|
+ "b2_1h_conver",
|
|
|
+ "b2_1h_click",
|
|
|
+ "b2_1h_conver*log(view)",
|
|
|
+ "b2_1h_conver*ctcvr",
|
|
|
+ "b2_2h_ctr",
|
|
|
+ "b2_2h_ctcvr",
|
|
|
+ "b2_2h_cvr",
|
|
|
+ "b2_2h_conver",
|
|
|
+ "b2_2h_click",
|
|
|
+ "b2_2h_conver*log(view)",
|
|
|
+ "b2_2h_conver*ctcvr",
|
|
|
+ "b2_3h_ctr",
|
|
|
+ "b2_3h_ctcvr",
|
|
|
+ "b2_3h_cvr",
|
|
|
+ "b2_3h_conver",
|
|
|
+ "b2_3h_click",
|
|
|
+ "b2_3h_conver*log(view)",
|
|
|
+ "b2_3h_conver*ctcvr",
|
|
|
+ "b2_6h_ctr",
|
|
|
+ "b2_6h_ctcvr");
|
|
|
+
|
|
|
+
|
|
|
SparkSession spark = SparkSession.builder()
|
|
|
.appName("XGBoostTrain")
|
|
|
//.master("local")
|
|
@@ -42,11 +73,18 @@ public class XGBoostTrain {
|
|
|
int label = NumberUtils.toInt(line[0]);
|
|
|
int[] indices = new int[line.length - 1];
|
|
|
double[] values = new double[line.length - 1];
|
|
|
+
|
|
|
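+ // Select features: index the parsed name:value pairs by name, then read them back in the order of the `features` list (absent names default to 0.0).
+ // NOTE(review): the split call in the loop below is passed only ":" and never references line[i] — confirm the intended input string.
+ // NOTE(review): indices/values are sized line.length - 1 but filled for features.size() entries — confirm the two lengths always agree.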
|
|
|
+ Map<String, Double> map = new HashMap<>();
|
|
|
for (int i = 1; i < line.length; i++) {
|
|
|
String[] fv = StringUtils.split(":");
|
|
|
- indices[i - 1] = i - 1;
|
|
|
- values[i - 1] = NumberUtils.toDouble(fv[1], 0.0);
|
|
|
+ map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
|
|
|
+ }
|
|
|
+ for (int i = 0; i < features.size(); i++) {
|
|
|
+ indices[i] = i;
|
|
|
+ values[i] = map.getOrDefault(features.get(i), 0.0);
|
|
|
}
|
|
|
+
|
|
|
SparseVector vector = new SparseVector(indices.length, indices, values);
|
|
|
return RowFactory.create(label, vector);
|
|
|
});
|
|
@@ -71,6 +109,7 @@ public class XGBoostTrain {
|
|
|
// Create the XGBoostClassifier object
|
|
|
XGBoostClassifier xgbClassifier = new XGBoostClassifier()
|
|
|
.setEta(0.1f)
|
|
|
+ .setMissing(0.0f)
|
|
|
.setFeaturesCol("features")
|
|
|
.setLabelCol("label")
|
|
|
.setMaxDepth(5)
|
|
@@ -87,6 +126,7 @@ public class XGBoostTrain {
|
|
|
Dataset<Row> predictions = model.transform(testData);
|
|
|
predictions.select("label", "prediction").show();
|
|
|
|
|
|
+
|
|
|
} catch (Exception e) {
|
|
|
log.error("", e);
|
|
|
}
|