|
@@ -81,8 +81,11 @@ object recsys_01_ros_multi_class_xgb_train {
|
|
|
val schema = DataTypes.createStructType(fields)
|
|
|
val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
|
|
|
|
|
|
- val numClasses = trainDataSet.select("label").distinct().count().toInt
|
|
|
- println(s"Label标签类别数: $numClasses")
|
|
|
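+ // Collect the distinct label values to the driver, then print their count and each value (getInt assumes an integer label column — TODO confirm)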
|
|
|
+ val distinctLabels = trainDataSet.select("label").distinct().collect()
|
|
|
+ println(s"Label标签类别数: ${distinctLabels.length}")
|
|
|
+ println("Distinct labels:")
|
|
|
+ distinctLabels.foreach(row => println(row.getInt(0)))
|
|
|
|
|
|
val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
|
|
|
val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label")
|