فهرست منبع

feat:添加ros回归模型

zhaohaipeng 1 ماه پیش
والد
کامیت
eb92f1f4d9

+ 11 - 1
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_ros_reg_xgb_train.scala

@@ -75,8 +75,9 @@ object recsys_01_ros_reg_xgb_train {
     fields = fields ++ Array(
       DataTypes.createStructField("logKey", DataTypes.StringType, true)
     )
-
     val schema = DataTypes.createStructType(fields)
+
+
     val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
     val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
     val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label").persist()
@@ -108,7 +109,14 @@ object recsys_01_ros_reg_xgb_train {
       features
     )
     val testDataSet = spark.createDataFrame(testData, schema)
+    println("recsys ros testDataSet schema")
+    testDataSet.printSchema()
+
     val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
+    println("recsys ros testDataSetTrans schema")
+    testDataSetTrans.printSchema()
+
+
     val predictions = model.transform(testDataSetTrans)
 
     // 保存评估结果
@@ -162,7 +170,9 @@ object recsys_01_ros_reg_xgb_train {
     data.map(r => {
       val line: Array[String] = StringUtils.split(r, '\t')
       val logKey = line(0)
+
       val label: Double = NumberUtils.toDouble(line(1))
+
       val map: util.Map[String, Double] = new util.HashMap[String, Double]
       for (i <- 2 until line.length) {
         val fv: Array[String] = StringUtils.split(line(i), ':')