|
@@ -6,7 +6,7 @@ import org.apache.commons.lang3.StringUtils
|
|
import org.apache.spark.ml.feature.VectorAssembler
|
|
import org.apache.spark.ml.feature.VectorAssembler
|
|
import org.apache.spark.rdd.RDD
|
|
import org.apache.spark.rdd.RDD
|
|
import org.apache.spark.sql.types.{DataTypes, StructField}
|
|
import org.apache.spark.sql.types.{DataTypes, StructField}
|
|
-import org.apache.spark.sql.{Row, RowFactory, SparkSession}
|
|
|
|
|
|
+import org.apache.spark.sql.{Dataset, Row, RowFactory, SparkSession}
|
|
|
|
|
|
import java.util
|
|
import java.util
|
|
|
|
|
|
@@ -25,13 +25,11 @@ object train_01_xgb_ad_20240808{
|
|
)
|
|
)
|
|
println("train data size:" + trainData.count())
|
|
println("train data size:" + trainData.count())
|
|
|
|
|
|
- val fields = new util.ArrayList[StructField]
|
|
|
|
- fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true))
|
|
|
|
- for (f <- features) {
|
|
|
|
- fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true))
|
|
|
|
- }
|
|
|
|
|
|
+ val fields = Array(
|
|
|
|
+ DataTypes.createStructField("label", DataTypes.IntegerType, true)
|
|
|
|
+ ) ++ features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
|
|
val schema = DataTypes.createStructType(fields)
|
|
val schema = DataTypes.createStructType(fields)
|
|
- val trainDataSet = spark.createDataFrame(trainData, schema)
|
|
|
|
|
|
+ val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
|
|
val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
|
|
val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
|
|
val xgbInput = vectorAssembler.transform(trainDataSet).select("features","label")
|
|
val xgbInput = vectorAssembler.transform(trainDataSet).select("features","label")
|
|
val xgbParam = Map("eta" -> 0.01f,
|
|
val xgbParam = Map("eta" -> 0.01f,
|
|
@@ -78,12 +76,12 @@ object train_01_xgb_ad_20240808{
|
|
map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
|
|
map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
|
|
}
|
|
}
|
|
|
|
|
|
- val v: Array[Object] = new Array[Object](features.length + 1)
|
|
|
|
|
|
+ val v: Array[AnyRef] = new Array[AnyRef](features.length + 1)
|
|
v(0) = label
|
|
v(0) = label
|
|
for (i <- 0 until features.length) {
|
|
for (i <- 0 until features.length) {
|
|
v(i + 1) = map.getOrDefault(features(i), 0.0d)
|
|
v(i + 1) = map.getOrDefault(features(i), 0.0d)
|
|
}
|
|
}
|
|
- RowFactory.create(v)
|
|
|
|
|
|
+ RowFactory.create(v: _*)
|
|
})
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|