丁云鹏 9 months ago
parent
commit
5902554448

+ 3 - 3
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostPredict.java

@@ -120,13 +120,13 @@ public class XGBoostPredict {
 
 
             XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
             XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
             Dataset<Row> predictions = model2.transform(assembledData);
             Dataset<Row> predictions = model2.transform(assembledData);
-            predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
+            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show(500);
 
 
             // 计算AUC
             // 计算AUC
-            Dataset<Row> selected = predictions.select("label", "probability");
+            Dataset<Row> selected = predictions.select("label", "rawPrediction");
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
                     .setLabelCol("label")
                     .setLabelCol("label")
-                    .setRawPredictionCol("probability")
+                    .setRawPredictionCol("rawPrediction")
                     .setMetricName("areaUnderROC");
                     .setMetricName("areaUnderROC");
             double auc = evaluator.evaluate(selected);
             double auc = evaluator.evaluate(selected);
 
 

+ 3 - 3
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrainLocalTest.java

@@ -59,13 +59,13 @@ public class XGBoostTrainLocalTest {
             // 显示预测结果
             // 显示预测结果
             Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
             Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
             Dataset<Row> predictions = model.transform(predictData);
             Dataset<Row> predictions = model.transform(predictData);
-            predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
+            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show();
 
 
             // 计算AUC
             // 计算AUC
-            Dataset<Row> selected = predictions.select("label", "probability");
+            Dataset<Row> selected = predictions.select("label", "rawPrediction");
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
                     .setLabelCol("label")
                     .setLabelCol("label")
-                    .setRawPredictionCol("probability")
+                    .setRawPredictionCol("rawPrediction")
                     .setMetricName("areaUnderROC");
                     .setMetricName("areaUnderROC");
             double auc = evaluator.evaluate(selected);
             double auc = evaluator.evaluate(selected);