|
@@ -10,7 +10,6 @@ import org.apache.spark.sql.{Row, SparkSession}
|
|
|
|
|
|
import java.net.URL
|
|
|
import scala.io.Source
|
|
|
-import scala.reflect.ClassTag.Any
|
|
|
|
|
|
object XGBoostTrain {
|
|
|
def main(args: Array[String]): Unit = {
|
|
@@ -43,7 +42,7 @@ object XGBoostTrain {
|
|
|
val label = NumberUtils.toInt(line(0))
|
|
|
|
|
|
val map = line.drop(1).map { entry =>
|
|
|
- val Array(key, value) = StringUtils.split(entry, ':')
|
|
|
+ val Array(key, value) = entry.split(":")
|
|
|
key -> NumberUtils.toDouble(value, 0.0)
|
|
|
}.toMap
|
|
|
|
|
@@ -60,7 +59,7 @@ object XGBoostTrain {
|
|
|
|
|
|
val fields = Seq(
|
|
|
StructField("label", DataTypes.IntegerType, true)
|
|
|
- ) ++ featureNameFile.map(f => StructField(f.toString, DataTypes.DoubleType, true))
|
|
|
+ ) ++ featureNameList.map(f => StructField(f.toString, DataTypes.DoubleType, true))
|
|
|
|
|
|
val dataset = spark.createDataFrame(rowRDD, StructType(fields))
|
|
|
|