丁云鹏 8 months ago
parent
commit
c38ab79177

+ 1 - 1
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java

@@ -49,7 +49,7 @@ public class XGBoostTrain {
             log.info("rowRDD count {}", rowRDD.count());
             // 将 JavaRDD<Row> 转换为 Dataset<Row>
             List<StructField> fields = new ArrayList<>();
-            fields.add(DataTypes.createStructField("label", DataTypes.StringType, true));
+            fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
             fields.add(DataTypes.createStructField("features", new ArrayType(DataTypes.DoubleType, true), true));
             StructType schema = DataTypes.createStructType(fields);
             Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);