|
@@ -91,18 +91,17 @@ public class XGBoostService {
|
|
|
|
|
|
// 显示预测结果
|
|
|
XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
|
|
|
- model2.setMissing(0.0f);
|
|
|
- model2.setFeaturesCol("features");
|
|
|
+ model2.setMissing(0.0f)
|
|
|
+ .setFeaturesCol("features");
|
|
|
|
|
|
Dataset<Row> predictions = model2.transform(assembledData);
|
|
|
- predictions.select("label", "prediction", "rawPrediction", "probability", "features").show(500);
|
|
|
+ predictions.select("label", "prediction", "rawPrediction", "probability", "features").show();
|
|
|
|
|
|
// 计算AUC
|
|
|
Dataset<Row> selected = predictions.select("label", "rawPrediction");
|
|
|
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
|
|
|
.setLabelCol("label")
|
|
|
.setRawPredictionCol("rawPrediction")
|
|
|
- .setMetricName("areaUnderROC");
|
|
|
double auc = evaluator.evaluate(selected);
|
|
|
log.info("AUC: {}", auc);
|
|
|
|