|
@@ -4,6 +4,7 @@ import lombok.extern.slf4j.Slf4j;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
|
|
|
import org.apache.commons.lang.math.NumberUtils;
|
|
|
+import org.apache.commons.lang3.RandomUtils;
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
|
import org.apache.spark.api.java.JavaRDD;
|
|
|
import org.apache.spark.api.java.JavaSparkContext;
|
|
@@ -30,9 +31,7 @@ public class XGBoostTrainLocalTest {
|
|
|
|
|
|
public static void main(String[] args) {
|
|
|
try {
|
|
|
- Dataset<Row> assembledData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
|
|
|
-
|
|
|
- // 创建 XGBoostClassifier 对象
|
|
|
+ // 训练
|
|
|
XGBoostClassifier xgbClassifier = new XGBoostClassifier()
|
|
|
.setEta(0.01f)
|
|
|
.setSubsample(0.8)
|
|
@@ -47,22 +46,23 @@ public class XGBoostTrainLocalTest {
|
|
|
.setNthread(1)
|
|
|
.setNumRound(100)
|
|
|
.setNumWorkers(1);
|
|
|
+ Dataset<Row> trainData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
|
|
|
+ trainData.show();
|
|
|
+ XGBoostClassificationModel model = xgbClassifier.fit(trainData);
|
|
|
|
|
|
|
|
|
- // 显示预测结果
|
|
|
+ // 预测
|
|
|
Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
|
|
|
-
|
|
|
- // 训练模型
|
|
|
- XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
|
|
|
+ model.setFeaturesCol("features").setMissing(0.0f);
|
|
|
Dataset<Row> predictions = model.transform(predictData);
|
|
|
- predictions.select("prediction", "rawPrediction", "probability", "features").show();
|
|
|
+ predictions.show();
|
|
|
+
|
|
|
|
|
|
// 计算AUC
|
|
|
Dataset<Row> selected = predictions.select("label", "rawPrediction");
|
|
|
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
|
|
|
.setLabelCol("label")
|
|
|
- .setRawPredictionCol("rawPrediction")
|
|
|
- .setMetricName("areaUnderROC");
|
|
|
+ .setRawPredictionCol("rawPrediction");
|
|
|
double auc = evaluator.evaluate(selected);
|
|
|
|
|
|
log.info("AUC: {}", auc);
|
|
@@ -73,21 +73,31 @@ public class XGBoostTrainLocalTest {
|
|
|
}
|
|
|
|
|
|
private static Dataset<Row> dataset(String path) {
|
|
|
- String[] features = {"cpa",
|
|
|
- "b2_12h_ctr",
|
|
|
- "b2_12h_ctcvr",
|
|
|
- "b2_12h_cvr",
|
|
|
- "b2_12h_conver",
|
|
|
- "b2_12h_click",
|
|
|
- "b2_12h_conver*log(view)",
|
|
|
- "b2_12h_conver*ctcvr",
|
|
|
- "b2_7d_ctr",
|
|
|
- "b2_7d_ctcvr",
|
|
|
- "b2_7d_cvr",
|
|
|
- "b2_7d_conver",
|
|
|
- "b2_7d_click",
|
|
|
- "b2_7d_conver*log(view)",
|
|
|
- "b2_7d_conver*ctcvr"
|
|
|
+ String[] features = {
|
|
|
+ "cpa",
|
|
|
+ "b2_1h_ctr",
|
|
|
+ "b2_1h_ctcvr",
|
|
|
+ "b2_1h_cvr",
|
|
|
+ "b2_1h_conver",
|
|
|
+ "b2_1h_click",
|
|
|
+ "b2_1h_conver*log(view)",
|
|
|
+ "b2_1h_conver*ctcvr",
|
|
|
+ "b2_2h_ctr",
|
|
|
+ "b2_2h_ctcvr",
|
|
|
+ "b2_2h_cvr",
|
|
|
+ "b2_2h_conver",
|
|
|
+ "b2_2h_click",
|
|
|
+ "b2_2h_conver*log(view)",
|
|
|
+ "b2_2h_conver*ctcvr",
|
|
|
+ "b2_3h_ctr",
|
|
|
+ "b2_3h_ctcvr",
|
|
|
+ "b2_3h_cvr",
|
|
|
+ "b2_3h_conver",
|
|
|
+ "b2_3h_click",
|
|
|
+ "b2_3h_conver*log(view)",
|
|
|
+ "b2_3h_conver*ctcvr",
|
|
|
+ "b2_6h_ctr",
|
|
|
+ "b2_6h_ctcvr"
|
|
|
};
|
|
|
|
|
|
|
|
@@ -115,7 +125,7 @@ public class XGBoostTrainLocalTest {
|
|
|
for (int i = 0; i < features.length; i++) {
|
|
|
v[i + 1] = map.getOrDefault(features[i], 0.0d);
|
|
|
}
|
|
|
-
|
|
|
+ //v[0] = (double) v[1] > 0.05 ? 1.0 : 0.0;
|
|
|
return RowFactory.create(v);
|
|
|
});
|
|
|
|
|
@@ -133,8 +143,7 @@ public class XGBoostTrainLocalTest {
|
|
|
.setInputCols(features)
|
|
|
.setOutputCol("features");
|
|
|
|
|
|
- Dataset<Row> assembledData = assembler.transform(dataset).select("features", "label");
|
|
|
- assembledData.show();
|
|
|
+ Dataset<Row> assembledData = assembler.transform(dataset);
|
|
|
return assembledData;
|
|
|
}
|
|
|
}
|