|
@@ -9,6 +9,7 @@ import org.apache.commons.lang3.RandomUtils;
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
|
import org.apache.spark.api.java.JavaRDD;
|
|
|
import org.apache.spark.api.java.JavaSparkContext;
|
|
|
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
|
|
|
import org.apache.spark.ml.feature.VectorAssembler;
|
|
|
import org.apache.spark.sql.Dataset;
|
|
|
import org.apache.spark.sql.Row;
|
|
@@ -112,6 +113,16 @@ public class XGBoostPredict {
|
|
|
Dataset<Row> predictions = model2.transform(assembledData);
|
|
|
predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
|
|
|
|
|
|
+ // 计算AUC
|
|
|
+ Dataset<Row> selected = predictions.select("label", "probability");
|
|
|
+ BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
|
|
|
+ .setLabelCol("label")
|
|
|
+ .setRawPredictionCol("probability")
|
|
|
+ .setMetricName("areaUnderROC");
|
|
|
+ double auc = evaluator.evaluate(selected);
|
|
|
+
|
|
|
+ log.info("AUC: {}", auc);
|
|
|
+
|
|
|
spark.close();
|
|
|
|
|
|
} catch (Throwable e) {
|