jch 4 miesięcy temu
rodzic
commit
eccbfd4999

+ 9 - 14
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/pred_recsys_61_xgb_nor_hdfsfile_20241209.scala

@@ -64,7 +64,7 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
 
     val testDataSet = spark.createDataFrame(testData, schema)
     val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label")
-    val predictions = model.transform(testDataSetTrans)
+    val predictions = model.transform(testDataSetTrans).persist()
 
     val saveData = predictions.select("label", "prediction").rdd
       .map(r => {
@@ -79,15 +79,19 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
       println("路径不合法,无法写入:" + hdfsPath)
     }
 
-    val evaluator = new RegressionEvaluator()
+    val rmseEvaluator = new RegressionEvaluator()
       .setLabelCol("label")
       .setPredictionCol("prediction")
       .setMetricName("rmse")
-    val rmse = evaluator.evaluate(predictions.select("label", "prediction"))
-    val selfRmse = calRMSE(predictions.select("label", "prediction").rdd)
+    val maeEvaluator = new RegressionEvaluator()
+      .setLabelCol("label")
+      .setPredictionCol("prediction")
+      .setMetricName("mae")
+    val rmse = rmseEvaluator.evaluate(predictions.select("label", "prediction"))
+    val mae = maeEvaluator.evaluate(predictions.select("label", "prediction"))
     val rmsle = calRMSLE(predictions.select("label", "prediction").rdd)
     printf("recsys nor:rmse: %.6f\n", rmse)
-    printf("recsys nor:selfRmse:%.6f\n", selfRmse)
+    printf("recsys nor:mae: %.6f\n", mae)
     printf("recsys nor:rmsle: %.6f\n", rmsle)
 
     println("---------------------------------\n")
@@ -113,15 +117,6 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
     })
   }
 
-  def calRMSE(evalRdd: RDD[Row]): Double = {
-    val sleRdd = evalRdd.map(raw => {
-      val label = raw.get(0).toString.toDouble
-      val pred = raw.get(1).toString.toDouble
-      math.pow(pred - label, 2)
-    })
-    math.sqrt(sleRdd.sum() / sleRdd.count())
-  }
-
   def calRMSLE(evalRdd: RDD[Row]): Double = {
     val sleRdd = evalRdd.map(raw => {
       val label = raw.get(0).toString.toDouble

+ 2 - 2
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/train_recsys_61_xgb_nor_20241209.scala

@@ -32,8 +32,8 @@ object train_recsys_61_xgb_nor_20241209 {
     val max_depth = param.getOrElse("max_depth", "5").toInt
     val num_round = param.getOrElse("num_round", "100").toInt
     val num_worker = param.getOrElse("num_worker", "20").toInt
-    val func_object = param.getOrElse("func_object", "reg:squaredlogerror")
-    val func_metric = param.getOrElse("func_metric", "rmsle")
+    val func_object = param.getOrElse("func_object", "reg:squarederror")
+    val func_metric = param.getOrElse("func_metric", "rmse")
     val repartition = param.getOrElse("repartition", "20").toInt
     val modelPath = param.getOrElse("modelPath", "/dw/recommend/model/61_recsys_nor_model/model_xgb")
     val modelFile = param.getOrElse("modelFile", "model_xgb_for_recsys_nor.tar.gz")