丁云鹏 8 months ago
parent
commit
9e2c185cb7

+ 4 - 5
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/XGBoostService.java

@@ -91,18 +91,17 @@ public class XGBoostService {
 
             // Show prediction results
             XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
-            model2.setMissing(0.0f);
-            model2.setFeaturesCol("features");
+            model2.setMissing(0.0f)
+                    .setFeaturesCol("features");
 
             Dataset<Row> predictions = model2.transform(assembledData);
-            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show(500);
+            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show();
 
             // Compute AUC
             Dataset<Row> selected = predictions.select("label", "rawPrediction");
             BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
                     .setLabelCol("label")
-                    .setRawPredictionCol("rawPrediction")
-                    .setMetricName("areaUnderROC");
+                    .setRawPredictionCol("rawPrediction");
             double auc = evaluator.evaluate(selected);
             log.info("AUC: {}", auc);
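
For context, a minimal sketch of how the evaluation block in XGBoostService.java reads after this change. This assumes destDir and assembledData are in scope as in the surrounding method, and relies on Spark's BinaryClassificationEvaluator defaulting to areaUnderROC, which is why dropping the explicit setMetricName call keeps the computed metric unchanged:

import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Reload the persisted model and configure it for scoring (chained setters, as in the change above).
XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
model2.setMissing(0.0f)
        .setFeaturesCol("features");

Dataset<Row> predictions = model2.transform(assembledData);
predictions.select("label", "prediction", "rawPrediction", "probability", "features").show();

// Compute AUC; the evaluator's default metric is areaUnderROC, so no setMetricName call is needed.
Dataset<Row> selected = predictions.select("label", "rawPrediction");
double auc = new BinaryClassificationEvaluator()
        .setLabelCol("label")
        .setRawPredictionCol("rawPrediction")
        .evaluate(selected);
log.info("AUC: {}", auc);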
 

+ 2 - 7
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrainLocalTest.java

@@ -30,7 +30,6 @@ public class XGBoostTrainLocalTest {
 
     public static void main(String[] args) {
         try {
-
             Dataset<Row> assembledData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
 
             // Create the XGBoostClassifier object
@@ -50,15 +49,11 @@ public class XGBoostTrainLocalTest {
                     .setNumWorkers(1);
 
 
-            // Train the model
-            XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
-
             // Show prediction results
             Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
-            model.setMissing(0.0f);
-            model.setFeaturesCol("features");
-            model.setTreeLimit(100);
 
+            // Train the model
+            XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
             Dataset<Row> predictions = model.transform(predictData);
             predictions.select("prediction", "rawPrediction", "probability", "features").show();
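
And a rough sketch of the local test flow in XGBoostTrainLocalTest.java after this commit, assuming the class's own dataset(...) helper and the xgbClassifier configured earlier in main. The per-model setMissing / setFeaturesCol / setTreeLimit overrides are gone, so scoring relies entirely on the classifier's configuration:

// Read the same local sample twice: once for fitting, once for scoring.
Dataset<Row> assembledData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");

// Train immediately before scoring; no setters are re-applied on the fitted model.
XGBoostClassificationModel model = xgbClassifier.fit(assembledData);

Dataset<Row> predictions = model.transform(predictData);
predictions.select("prediction", "rawPrediction", "probability", "features").show();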