丁云鹏 8 月之前
父節點
當前提交
5819bcb850

+ 7 - 4
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/service/XGBoostService.java

@@ -74,7 +74,8 @@ public class XGBoostService {
     public void predict(String[] args) {
         try {
 
-            Dataset<Row> assembledData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz");
+            Dataset<Row> assembledData =
+                    dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz").select("features");
             log.info("测试样本 show");
             assembledData.show();
 
@@ -91,6 +92,8 @@ public class XGBoostService {
             // 显示预测结果
             XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
             model2.setMissing(0.0f);
+            model2.setFeaturesCol("features");
+
             Dataset<Row> predictions = model2.transform(assembledData);
             predictions.select("label", "prediction", "rawPrediction", "probability", "features").show(500);
 
@@ -147,7 +150,7 @@ public class XGBoostService {
 
         JavaRDD<Row> rowRDD = rdd.map(s -> {
             String[] line = StringUtils.split(s, '\t');
-            int label = NumberUtils.toInt(line[0]);
+            double label = NumberUtils.toDouble(line[0]);
             // 选特征
             Map<String, Double> map = new HashMap<>();
             for (int i = 1; i < line.length; i++) {
@@ -167,7 +170,7 @@ public class XGBoostService {
         log.info("rowRDD count {}", rowRDD.count());
         // 将 JavaRDD<Row> 转换为 Dataset<Row>
         List<StructField> fields = new ArrayList<>();
-        fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
+        fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
         for (String f : features) {
             fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
         }
@@ -178,7 +181,7 @@ public class XGBoostService {
                 .setInputCols(features)
                 .setOutputCol("features");
 
-        Dataset<Row> assembledData = assembler.transform(dataset);
+        Dataset<Row> assembledData = assembler.transform(dataset).select("features", "label");
         assembledData.show();
         return assembledData;
     }

+ 7 - 4
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrainLocalTest.java

@@ -56,8 +56,11 @@ public class XGBoostTrainLocalTest {
             // 显示预测结果
             Dataset<Row> predictData = dataset("/Users/dingyunpeng/Desktop/part-00099.gz");
             model.setMissing(0.0f);
+            model.setFeaturesCol("features");
+            model.setTreeLimit(100);
+
             Dataset<Row> predictions = model.transform(predictData);
-            predictions.select("label", "prediction", "rawPrediction", "probability", "features").show();
+            predictions.select("prediction", "rawPrediction", "probability", "features").show();
 
             // 计算AUC
             Dataset<Row> selected = predictions.select("label", "rawPrediction");
@@ -104,7 +107,7 @@ public class XGBoostTrainLocalTest {
 
         JavaRDD<Row> rowRDD = rdd.map(s -> {
             String[] line = StringUtils.split(s, '\t');
-            int label = NumberUtils.toInt(line[0]);
+            double label = NumberUtils.toDouble(line[0]);
             // 选特征
             Map<String, Double> map = new HashMap<>();
             for (int i = 1; i < line.length; i++) {
@@ -124,7 +127,7 @@ public class XGBoostTrainLocalTest {
         log.info("rowRDD count {}", rowRDD.count());
         // 将 JavaRDD<Row> 转换为 Dataset<Row>
         List<StructField> fields = new ArrayList<>();
-        fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
+        fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
         for (String f : features) {
             fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
         }
@@ -135,7 +138,7 @@ public class XGBoostTrainLocalTest {
                 .setInputCols(features)
                 .setOutputCol("features");
 
-        Dataset<Row> assembledData = assembler.transform(dataset);
+        Dataset<Row> assembledData = assembler.transform(dataset).select("features", "label");
         assembledData.show();
         return assembledData;
     }