丁云鹏 8 månader sedan
förälder
incheckning
d36008d6ae

+ 11 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostPredict.java

@@ -9,6 +9,7 @@ import org.apache.commons.lang3.RandomUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 import org.apache.spark.ml.feature.VectorAssembler;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
@@ -112,6 +113,16 @@ public class XGBoostPredict {
             Dataset<Row> predictions = model2.transform(assembledData);
             predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
 
+            // 计算AUC
+            Dataset<Row> selected = predictions.select("label", "probability");
+            BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
+                    .setLabelCol("label")
+                    .setRawPredictionCol("probability")
+                    .setMetricName("areaUnderROC");
+            double auc = evaluator.evaluate(selected);
+
+            log.info("AUC: {}", auc);
+
             spark.close();
 
         } catch (Throwable e) {