|
@@ -49,7 +49,7 @@ object XGBoostTrain {
|
|
|
v(0) = label
|
|
|
|
|
|
for (index <- featureNameList.indices) {
|
|
|
- v(index + 1) = map.getOrElse(featureNameList(1), 0.0)
|
|
|
+ v(index + 1) = map.getOrElse(featureNameList(index), 0.0)
|
|
|
}
|
|
|
|
|
|
Row.fromSeq(v)
|
|
@@ -58,7 +58,7 @@ object XGBoostTrain {
|
|
|
|
|
|
val fields = Seq(
|
|
|
StructField("label", DataTypes.IntegerType, true)
|
|
|
- ) ++ featureNameList.map(f => StructField(f.toString, DataTypes.DoubleType, true))
|
|
|
+ ) ++ featureNameList.map(f => StructField(f, DataTypes.DoubleType, true))
|
|
|
|
|
|
val dataset = spark.createDataFrame(rowRDD, StructType(fields))
|
|
|
|