|  | @@ -81,8 +81,11 @@ object recsys_01_ros_multi_class_xgb_train {
 | 
											
												
													
														|  |      val schema = DataTypes.createStructType(fields)
 |  |      val schema = DataTypes.createStructType(fields)
 | 
											
												
													
														|  |      val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
 |  |      val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | -    val numClasses = trainDataSet.select("label").distinct().count().toInt
 |  | 
 | 
											
												
													
														|  | -    println(s"Label标签类别数: $numClasses")
 |  | 
 | 
											
												
													
														|  | 
 |  | +    // 打印去重后的label值
 | 
											
												
													
														|  | 
 |  | +    val distinctLabels = trainDataSet.select("label").distinct().collect()
 | 
											
												
													
														|  | 
 |  | +    println(s"Label标签类别数: ${distinctLabels.length}")
 | 
											
												
													
														|  | 
 |  | +    println("Distinct labels:")
 | 
											
												
													
														|  | 
 |  | +    distinctLabels.foreach(row => println(row.getInt(0)))
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |      val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
 |  |      val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
 | 
											
												
													
														|  |      val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label")
 |  |      val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label")
 |