Bläddra i källkod

调整ros-训练方式

jch 2 månader sedan
förälder
incheckning
707bc6d633

+ 40 - 0
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/MetricUtils.scala

@@ -0,0 +1,40 @@
+package com.tzld.piaoquan.recommend.model
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Row
+
+object MetricUtils {
+  def logScale(label: Double, logType: Int, logBase: Double): Double = {
+    if (0 == logType) {
+      label
+    } else {
+      Math.log(1 + label) / Math.log(logBase)
+    }
+  }
+
+  def restoreLog(predict: Double, logType: Int, logBase: Double): Double = {
+    if (0 == logType) {
+      predict
+    } else {
+      Math.exp(predict * Math.log(logBase)) - 1
+    }
+  }
+
+  def calMAPE(evalRdd: RDD[Row]): Double = {
+    val apeRdd = evalRdd.map(raw => {
+      val label = raw.get(0).toString.toDouble
+      val pred = raw.get(1).toString.toDouble
+      math.abs(label - pred) / label
+    })
+    apeRdd.sum() / apeRdd.count()
+  }
+
+  def calRMSLE(evalRdd: RDD[Row]): Double = {
+    val sleRdd = evalRdd.map(raw => {
+      val label = raw.get(0).toString.toDouble
+      val pred = raw.get(1).toString.toDouble
+      math.pow(math.log(pred + 1) - math.log(label + 1), 2)
+    })
+    math.sqrt(sleRdd.sum() / sleRdd.count())
+  }
+}

+ 7 - 23
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/pred_recsys_61_xgb_nor_hdfsfile_20241209.scala

@@ -24,6 +24,8 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
     val param = ParamUtils.parseArgs(args)
     val featureFile = param.getOrElse("featureFile", "20241209_recsys_nor_name.txt")
     val testPath = param.getOrElse("testPath", "")
+    val labelLogType = param.getOrElse("labelLogType", "0").toInt
+    val labelLogBase = param.getOrElse("labelLogBase", "2").toDouble
     val savePath = param.getOrElse("savePath", "/dw/recommend/model/61_recsys_nor_predict_data/")
     val featureFilter = param.getOrElse("featureFilter", "XXXXXX").split(",")
 
@@ -69,7 +71,7 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
     val testDataSet = spark.createDataFrame(testData, schema)
     val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey", "scoresMap")
     val predictions = model.transform(testDataSetTrans)
-    val clipPrediction = getClipData(spark, predictions).persist()
+    val clipPrediction = getClipData(spark, predictions, labelLogType, labelLogBase).persist()
 
     val saveData = clipPrediction.select("label", "prediction", "clipPrediction", "logKey", "scoresMap").rdd
       .map(r => {
@@ -94,8 +96,8 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
       .setMetricName("mae")
     val rmse = rmseEvaluator.evaluate(clipPrediction.select("label", "clipPrediction"))
     val mae = maeEvaluator.evaluate(clipPrediction.select("label", "clipPrediction"))
-    val mape = calMAPE(clipPrediction.select("label", "clipPrediction").rdd)
-    val rmsle = calRMSLE(clipPrediction.select("label", "clipPrediction").rdd)
+    val mape = MetricUtils.calMAPE(clipPrediction.select("label", "clipPrediction").rdd)
+    val rmsle = MetricUtils.calRMSLE(clipPrediction.select("label", "clipPrediction").rdd)
     printf("recsys nor:rmse: %.6f\n", rmse)
     printf("recsys nor:mae: %.6f\n", mae)
     printf("recsys nor:mape: %.6f\n", mape)
@@ -128,12 +130,12 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
     })
   }
 
-  def getClipData(spark: SparkSession, df: DataFrame): DataFrame = {
+  def getClipData(spark: SparkSession, df: DataFrame, logType: Int, logBase: Double): DataFrame = {
     import spark.implicits._
     df.select("label", "prediction", "logKey", "scoresMap").rdd
       .map(row => {
         val label = row.getAs[Double]("label")
-        val prediction = row.getAs[Double]("prediction")
+        val prediction = MetricUtils.restoreLog(row.getAs[Double]("prediction"), logType, logBase)
         val logKey = row.getAs[String]("logKey")
         val scoresMap = row.getAs[String]("scoresMap")
         if (prediction < 1E-8) {
@@ -144,22 +146,4 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
       }
       ).toDF("label", "prediction", "clipPrediction", "logKey", "scoresMap")
   }
-
-  def calMAPE(evalRdd: RDD[Row]): Double = {
-    val apeRdd = evalRdd.map(raw => {
-      val label = raw.get(0).toString.toDouble
-      val pred = raw.get(1).toString.toDouble
-      math.abs(label - pred) / label
-    })
-    apeRdd.sum() / apeRdd.count()
-  }
-
-  def calRMSLE(evalRdd: RDD[Row]): Double = {
-    val sleRdd = evalRdd.map(raw => {
-      val label = raw.get(0).toString.toDouble
-      val pred = raw.get(1).toString.toDouble
-      math.pow(math.log(pred + 1) - math.log(label + 1), 2)
-    })
-    math.sqrt(sleRdd.sum() / sleRdd.count())
-  }
 }

+ 10 - 13
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/train_recsys_61_xgb_nor_20241209.scala

@@ -27,6 +27,8 @@ object train_recsys_61_xgb_nor_20241209 {
     val testPath = param.getOrElse("testPath", "")
     val savePath = param.getOrElse("savePath", "/dw/recommend/model/61_recsys_nor_predict_data/")
     val featureFilter = param.getOrElse("featureFilter", "XXXXXX").split(",")
+    val labelLogType = param.getOrElse("labelLogType", "0").toInt
+    val labelLogBase = param.getOrElse("labelLogBase", "2").toDouble
     val eta = param.getOrElse("eta", "0.01").toDouble
     val gamma = param.getOrElse("gamma", "0.0").toDouble
     val max_depth = param.getOrElse("max_depth", "5").toInt
@@ -56,6 +58,8 @@ object train_recsys_61_xgb_nor_20241209 {
     println("features.size=" + features.length)
 
     val trainData = createData(
+      labelLogType,
+      labelLogBase,
       sc.textFile(trainPath),
       features
     )
@@ -92,12 +96,14 @@ object train_recsys_61_xgb_nor_20241209 {
 
     if (modelPath.nonEmpty && modelFile.nonEmpty) {
       model.write.overwrite.save(modelPath)
-      //      val gzPath = modelPath + "/" + modelFile
-      //      CompressUtil.compressDirectoryToGzip(modelPath, gzPath)
+      // val gzPath = modelPath + "/" + modelFile
+      // CompressUtil.compressDirectoryToGzip(modelPath, gzPath)
     }
 
     if (testPath.nonEmpty) {
       val testData = createData(
+        labelLogType,
+        labelLogBase,
         sc.textFile(testPath),
         features
       )
@@ -124,11 +130,10 @@ object train_recsys_61_xgb_nor_20241209 {
         .setMetricName("rmse")
       val rmse = evaluator.evaluate(predictions.select("label", "prediction"))
       println("recsys nor: rmse:" + rmse)
-
     }
   }
 
-  def createData(data: RDD[String], features: Array[String]): RDD[Row] = {
+  def createData(logType: Int, logBase: Double, data: RDD[String], features: Array[String]): RDD[Row] = {
     data
       .filter(r => {
         val line: Array[String] = StringUtils.split(r, '\t')
@@ -146,19 +151,11 @@ object train_recsys_61_xgb_nor_20241209 {
         }
 
         val v: Array[Any] = new Array[Any](features.length + 1)
-        v(0) = label
+        v(0) = MetricUtils.logScale(label, logType, logBase)
         for (i <- 0 until features.length) {
           v(i + 1) = map.getOrDefault(features(i), 0.0d)
         }
         Row(v: _*)
       })
   }
-
-  def clipLabel(label: Double, maxVal: Double): Double = {
-    if (label < maxVal) {
-      label
-    } else {
-      label + 2 * Math.log(label - maxVal + 1)
-    }
-  }
 }