|
@@ -120,13 +120,13 @@ public class XGBoostPredict {
|
|
|
|
|
|
XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
|
|
|
Dataset<Row> predictions = model2.transform(assembledData);
|
|
|
- predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
|
|
|
+ predictions.select("label", "prediction", "rawPrediction", "probability", "features").show(500);
|
|
|
|
|
|
|
|
|
- Dataset<Row> selected = predictions.select("label", "probability");
|
|
|
+ Dataset<Row> selected = predictions.select("label", "rawPrediction");
|
|
|
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
|
|
|
.setLabelCol("label")
|
|
|
- .setRawPredictionCol("probability")
|
|
|
+ .setRawPredictionCol("rawPrediction")
|
|
|
.setMetricName("areaUnderROC");
|
|
|
double auc = evaluator.evaluate(selected);
|
|
|
|