浏览代码

修改 nor 样本标签类型（label 由 Int 改为 Double）

jch 4 月之前
父节点
当前提交
e1eb57a69d

+ 2 - 2
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/pred_recsys_61_xgb_nor_hdfsfile_20241209.scala

@@ -48,7 +48,7 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
     println("features.size=" + features.length)
 
     val fields = Array(
-      DataTypes.createStructField("label", DataTypes.IntegerType, true)
+      DataTypes.createStructField("label", DataTypes.DoubleType, true)
     ) ++ features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
 
     val schema = DataTypes.createStructType(fields)
@@ -101,7 +101,7 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
   def createData(data: RDD[String], features: Array[String]): RDD[Row] = {
     data.map(r => {
       val line: Array[String] = StringUtils.split(r, '\t')
-      val label: Int = NumberUtils.toInt(line(0))
+      val label: Double = NumberUtils.toDouble(line(0))
       val map: util.Map[String, Double] = new util.HashMap[String, Double]
       for (i <- 1 until line.length) {
         val fv: Array[String] = StringUtils.split(line(i), ':')

+ 2 - 2
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/train_recsys_61_xgb_nor_20241209.scala

@@ -62,7 +62,7 @@ object train_recsys_61_xgb_nor_20241209 {
     println("recsys nor:train data size:" + trainData.count())
 
     val fields = Array(
-      DataTypes.createStructField("label", DataTypes.IntegerType, true)
+      DataTypes.createStructField("label", DataTypes.DoubleType, true)
     ) ++ features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
 
     val schema = DataTypes.createStructType(fields)
@@ -131,7 +131,7 @@ object train_recsys_61_xgb_nor_20241209 {
   def createData(data: RDD[String], features: Array[String]): RDD[Row] = {
     data.map(r => {
       val line: Array[String] = StringUtils.split(r, '\t')
-      val label: Int = NumberUtils.toInt(line(0))
+      val label: Double = NumberUtils.toDouble(line(0))
       val map: util.Map[String, Double] = new util.HashMap[String, Double]
       for (i <- 1 until line.length) {
         val fv: Array[String] = StringUtils.split(line(i), ':')