丁云鹏 8 months ago
parent
commit
3832eb2366

+ 12 - 4
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/xgboost/XGBoostTrain.java

@@ -11,6 +11,13 @@ import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
 import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+import java.util.ArrayList;
+import java.util.List;
 
 /**
  * @author dyp
@@ -41,10 +48,11 @@ public class XGBoostTrain {
             });
             log.info("rowRDD count {}", rowRDD.count());
             // 将 JavaRDD<Row> 转换为 Dataset<Row>
-            Dataset<Row> dataset = spark.createDataFrame(rowRDD, Row.class);
-
-            // 如果需要,可以添加列名
-            dataset = dataset.toDF("label", "features");
+            List<StructField> fields = new ArrayList<>();
+            fields.add(DataTypes.createStructField("label", DataTypes.StringType, true));
+            fields.add(DataTypes.createStructField("features", new ArrayType(DataTypes.DoubleType, true), true));
+            StructType schema = DataTypes.createStructType(fields);
+            Dataset<Row> dataset = spark.createDataFrame(rowRDD, schema);
 
             // 划分训练集和测试集
             Dataset<Row>[] splits = dataset.randomSplit(new double[]{0.7, 0.3});