
feat: add ros regression model

zhaohaipeng 1 month ago
commit a248a9ffdc

+ 4 - 0
recommend-model-produce/src/main/java/com/tzld/piaoquan/recommend/model/produce/util/RosUtil.java

@@ -35,4 +35,8 @@ public class RosUtil {
 
     }
 
+    public static Double inverseLog(double y) {
+        return Math.exp(y) - 1.0;
+    }
+
 }
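
The new helper is the inverse of a log(1 + x) transform: exp(y) - 1 recovers x from y = log(1 + x). A minimal round-trip sketch in Scala, assuming the ROS label is log1p-transformed before training (that transform is not part of this diff):

    // Hypothetical round trip; rawLabel is an invented sample value.
    val rawLabel = 12.0
    val y = math.log(1.0 + rawLabel)      // assumed forward transform at training time
    val restored = RosUtil.inverseLog(y)  // exp(y) - 1.0, back to ~12.0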

+ 69 - 51
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_ros_reg_xgb_train.scala

@@ -1,5 +1,7 @@
 package com.tzld.piaoquan.recommend.model
 
+import com.alibaba.fastjson.JSON
+import com.tzld.piaoquan.recommend.model.produce.util.RosUtil
 import com.tzld.piaoquan.recommend.utils.{MyHdfsUtils, ParamUtils}
 import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
 import org.apache.commons.lang.math.NumberUtils
@@ -96,60 +98,76 @@ object recsys_01_ros_reg_xgb_train {
       model.write.overwrite.save(modelPath)
     }
 
-    if (testPath.nonEmpty) {
-      val testData = createData(
-        sc.textFile(testPath),
-        features
-      )
-      val testDataSet = spark.createDataFrame(testData, schema)
-      val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label")
-      val predictions = model.transform(testDataSetTrans)
-
-      println("recsys ros:columns:" + predictions.columns.mkString(",")) //[label, features, prediction]
-      val saveData = predictions.select("label", "prediction").rdd
-        .map(r => {
-          (r.get(0), r.get(1)).productIterator.mkString("\t")
-        })
-      val hdfsPath = savePath
-      if (hdfsPath.nonEmpty && hdfsPath.startsWith("/dw/recommend/model/")) {
-        println("Deleting path and writing data to: " + hdfsPath)
-        MyHdfsUtils.delete_hdfs_path(hdfsPath)
-        saveData.repartition(repartition).saveAsTextFile(hdfsPath, classOf[GzipCodec])
-      } else {
-        println("Invalid path, cannot write: " + hdfsPath)
-      }
-      val evaluator = new RegressionEvaluator()
-        .setLabelCol("label")
-        .setPredictionCol("prediction")
-        .setMetricName("rmse")
-      val rmse = evaluator.evaluate(predictions.select("label", "prediction"))
-      println("recsys nor: rmse:" + rmse)
+    // Evaluation data
+    val testData = createData(
+      sc.textFile(testPath),
+      features
+    )
+    val testDataSet = spark.createDataFrame(testData, schema)
+    val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
+    val predictions = model.transform(testDataSetTrans)
+
+    // Save the evaluation results
+    println("recsys ros:columns:" + predictions.columns.mkString(",")) //[label, features, prediction]
+    val saveData = predictions.select("label", "prediction", "logKey").rdd
+      .map(r => {
+        (r.get(0), r.get(1), r.get(2)).productIterator.mkString("\t")
+      })
+    val hdfsPath = savePath
+    if (hdfsPath.nonEmpty && hdfsPath.startsWith("/dw/recommend/model/")) {
+      println("删除路径并开始数据写入:" + hdfsPath)
+      MyHdfsUtils.delete_hdfs_path(hdfsPath)
+      saveData.repartition(repartition).saveAsTextFile(hdfsPath, classOf[GzipCodec])
+    } else {
+      println("路径不合法,无法写入:" + hdfsPath)
     }
+
+    // Compute RMSE
+    val evaluator = new RegressionEvaluator()
+      .setLabelCol("label")
+      .setPredictionCol("prediction")
+      .setMetricName("rmse")
+    val rmse = evaluator.evaluate(predictions.select("label", "prediction"))
+    println("recsys nor: rmse:" + rmse)
+
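+    // Re-read the saved predictions and print (vid, count, score sum, mean score) per vid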
+    sc.textFile(hdfsPath)
+      .map(r => {
+        val rList = r.split("\t")
+        val vid = JSON.parseObject(rList(2)).getString("vid")
+        val score = rList(1).toDouble
+        (vid, (1, score))
+      })
+      .reduceByKey { case ((cnt1, sum1), (cnt2, sum2)) =>
+        (cnt1 + cnt2, sum1 + sum2)
+      }
+      .map { case (vid, (all, sumScore)) =>
+        (vid, all, sumScore, sumScore / all)
+      }
+      .collect()
+      .sortBy(_._1)
+      .foreach(t => println(t.productIterator.mkString("\t")))
+
   }
 
   def createData(data: RDD[String], features: Array[String]): RDD[Row] = {
-    data
-      .filter(r => {
-        val line: Array[String] = StringUtils.split(r, '\t')
-        line.length > 10
-      })
-      .map(r => {
-        val line: Array[String] = StringUtils.split(r, '\t')
-        // val logKey = line(0)
-        val label: Double = NumberUtils.toDouble(line(1))
-        // val scoresMap = line(2)
-        val map: util.Map[String, Double] = new util.HashMap[String, Double]
-        for (i <- 3 until line.length) {
-          val fv: Array[String] = StringUtils.split(line(i), ':')
-          map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
-        }
-
-        val v: Array[Any] = new Array[Any](features.length + 2)
-        v(0) = label
-        for (i <- 0 until features.length) {
-          v(i + 1) = map.getOrDefault(features(i), 0.0d)
-        }
-        Row(v: _*)
-      })
+    data
+      .filter(r => StringUtils.split(r, '\t').length >= 2) // guard against malformed lines
+      .map(r => {
+        val line: Array[String] = StringUtils.split(r, '\t')
+        val logKey = line(0)
+        val label: Double = NumberUtils.toDouble(line(1))
+        val map: util.Map[String, Double] = new util.HashMap[String, Double]
+        for (i <- 2 until line.length) {
+          val fv: Array[String] = StringUtils.split(line(i), ':')
+          map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
+        }
+
+        // Row layout: label first, then the feature values in order, logKey last
+        val v: Array[Any] = new Array[Any](features.length + 2)
+        v(0) = label
+        for (i <- 0 until features.length) {
+          v(i + 1) = map.getOrDefault(features(i), 0.0d)
+        }
+        v(features.length + 1) = logKey
+        Row(v: _*)
+      })
   }
 }
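
Each record written to savePath is a tab-separated triple of label, prediction, and logKey, where logKey is a JSON string carrying at least a "vid" field. A minimal parsing sketch for one such record (sample values invented for illustration):

    import com.alibaba.fastjson.JSON

    val record = "0.693\t0.651\t{\"vid\":\"12345\"}"        // hypothetical saved line
    val fields = record.split("\t")
    val label = fields(0).toDouble                          // original label
    val score = fields(1).toDouble                          // model prediction
    val vid = JSON.parseObject(fields(2)).getString("vid")  // video id from logKey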