|
@@ -62,8 +62,9 @@ object train_01_xgb_ad_20240808{
|
|
|
println("zhangbo:train data size:" + trainData.count())
|
|
|
|
|
|
val fields = Array(
|
|
|
- DataTypes.createStructField("logKey", DataTypes.StringType, true),
|
|
|
- DataTypes.createStructField("label", DataTypes.IntegerType, true)
|
|
|
+ DataTypes.createStructField("label", DataTypes.IntegerType, true),
|
|
|
+ DataTypes.createStructField("logKey", DataTypes.StringType, true)
|
|
|
+
|
|
|
) ++ features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
|
|
|
val schema = DataTypes.createStructType(fields)
|
|
|
val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
|
|
@@ -138,8 +139,8 @@ object train_01_xgb_ad_20240808{
|
|
|
}
|
|
|
})
|
|
|
val res = new ArrayBuffer[Any]()
|
|
|
- res.add(cid)
|
|
|
res.add(label)
|
|
|
+ res.add(cid)
|
|
|
features.foreach(r =>{
|
|
|
res.add(featureMap.getOrElse(r, 0.0D))
|
|
|
})
|