Browse Source

scala train

zhangbo 8 months ago
parent
commit
f54cfd3235

+ 7 - 9
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/train_01_xgb_ad_20240808.scala

@@ -6,7 +6,7 @@ import org.apache.commons.lang3.StringUtils
 import org.apache.spark.ml.feature.VectorAssembler
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types.{DataTypes, StructField}
-import org.apache.spark.sql.{Row, RowFactory, SparkSession}
+import org.apache.spark.sql.{Dataset, Row, RowFactory, SparkSession}
 
 import java.util
 
@@ -25,13 +25,11 @@ object train_01_xgb_ad_20240808{
     )
     println("train data size:" + trainData.count())
 
-    val fields = new util.ArrayList[StructField]
-    fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true))
-    for (f <- features) {
-      fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true))
-    }
+    val fields = Array(
+      DataTypes.createStructField("label", DataTypes.IntegerType, true)
+    ) ++ features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
     val schema = DataTypes.createStructType(fields)
-    val trainDataSet = spark.createDataFrame(trainData, schema)
+    val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
     val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
     val xgbInput = vectorAssembler.transform(trainDataSet).select("features","label")
     val xgbParam = Map("eta" -> 0.01f,
@@ -78,12 +76,12 @@ object train_01_xgb_ad_20240808{
         map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
       }
 
-      val v: Array[Object] = new Array[Object](features.length + 1)
+      val v: Array[AnyRef] = new Array[AnyRef](features.length + 1)
       v(0) = label
       for (i <- 0 until features.length) {
         v(i + 1) = map.getOrDefault(features(i), 0.0d)
       }
-      RowFactory.create(v)
+      RowFactory.create(v: _*)
     })
   }
 }