Browse Source

feat:添加ros回归模型

zhaohaipeng 1 month ago
parent
commit
dc056f9344

+ 6 - 2
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_ros_reg_xgb_train.scala

@@ -68,10 +68,14 @@ object recsys_01_ros_reg_xgb_train {
     )
     println("recsys ros:train data size:" + trainData.count())
 
-    val fields = Array(
+    var fields = Array(
       DataTypes.createStructField("label", DataTypes.DoubleType, true)
     ) ++ features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
 
+    fields = fields ++ Array(
+      DataTypes.createStructField("logKey", DataTypes.StringType, true)
+    )
+
     val schema = DataTypes.createStructType(fields)
     val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
     val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
@@ -104,7 +108,7 @@ object recsys_01_ros_reg_xgb_train {
       features
     )
     val testDataSet = spark.createDataFrame(testData, schema)
-    val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label")
+    val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
     val predictions = model.transform(testDataSetTrans)
 
     // 保存评估结果