Browse Source

feat:添加ros分析脚本

zhaohaipeng 1 month ago
parent
commit
2bbd98e8cd

+ 1 - 1
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_ros_multi_class_xgb_train.scala

@@ -26,7 +26,7 @@ object recsys_01_ros_multi_class_xgb_train {
 
     val spark = SparkSession
       .builder()
-      .appName(this.getClass.getName + " : " + dt)
+      .appName(this.getClass.getName + ": " + dt)
       .getOrCreate()
     val sc = spark.sparkContext
 

+ 104 - 0
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_02_ros_multi_class_predict_analyse.scala

@@ -0,0 +1,104 @@
+package com.tzld.piaoquan.recommend.model
+
+import com.tzld.piaoquan.recommend.utils.ParamUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
+
+import java.time.LocalDateTime
+import java.time.format.DateTimeFormatter
+
+object recsys_02_ros_multi_class_predict_analyse {
+  /**
+   * Job entry point: loads multi-class prediction output, then prints the per-label AUC,
+   * the global accuracy and the per-label accuracy to stdout.
+   *
+   * @param args command-line arguments handed to `ParamUtils.parseArgs`; an optional
+   *             `readPath` entry replaces the default input location
+   */
+  def main(args: Array[String]): Unit = {
+    // Minute-resolution timestamp; it is only used as a suffix of the Spark application name.
+    val runStamp = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMddHHmm"))
+
+    val session = SparkSession
+      .builder()
+      .appName(s"${this.getClass.getName}: $runStamp")
+      .getOrCreate()
+    val context = session.sparkContext
+
+    val options = ParamUtils.parseArgs(args)
+    val inputPath = options.getOrElse("readPath", "/dw/recommend/model/44_recsys_ros_predict/")
+
+    // Parsed lines are cached because both metric computations read them.
+    val parsed = context.textFile(inputPath).map(line => parseLine(line)).cache()
+
+    // Per-label AUC, computed from the (isLabel, score) half of every parsed line.
+    val aucByLabel = computeAUC(parsed.map { case (scored, _) => scored })
+
+    // Accuracy figures, computed from the (isLabel, isArgMax) half of every parsed line.
+    val accuracy = computeAccuracyRate(parsed.map { case (_, flagged) => flagged })
+    val overallAccuracy = accuracy._1
+    val accuracyByLabel = accuracy._2
+
+    // Report the results.
+    println("AUC Scores:")
+    for (idx <- aucByLabel.indices) {
+      println(s"Label $idx: AUC = ${aucByLabel(idx)}")
+    }
+
+    println(s"\nGlobal Accuracy: $overallAccuracy")
+
+    println("\nPer Label Accuracy:")
+    for (idx <- accuracyByLabel.indices) {
+      println(s"Label $idx: Accuracy = ${accuracyByLabel(idx)}")
+    }
+
+  }
+
+  /**
+   * Turns one tab-separated prediction line into the per-class tuples consumed by the metrics.
+   *
+   * Field 0 is read as the integer label and field 2 as a bracketed, comma-separated list of
+   * scores (one per class); field 1 is not used.
+   *
+   * @param line raw input line
+   * @return two arrays with one entry per score: first `(isLabel, score)`, second
+   *         `(isLabel, isArgMax)`. `isLabel` is 1 only at the index equal to the label and
+   *         `isArgMax` is 1 only at the index of the highest score.
+   */
+  def parseLine(line: String): (Array[(Int, Double)], Array[(Int, Int)]) = {
+    val fields = line.split("\t")
+    val trueLabel = fields(0).toInt
+    val stripped = fields(2).replace("[", "").replace("]", "").trim
+    val probs = stripped.split(",").map(token => token.toDouble)
+
+    // Index of the winning class, i.e. the position holding the largest score.
+    val (_, predicted) = probs.zipWithIndex.maxBy { case (p, _) => p }
+
+    // 1 at the true label's position, 0 everywhere else.
+    def labelFlag(i: Int): Int = if (i == trueLabel) 1 else 0
+
+    // (isLabel, score) per class: input of the one-vs-rest AUC.
+    val aucPairs = probs.zipWithIndex.map { case (p, i) => (labelFlag(i), p) }
+
+    // (isLabel, isArgMax) per class: input of the accuracy computation.
+    val hitPairs = Array.tabulate(probs.length) { i =>
+      (labelFlag(i), if (i == predicted) 1 else 0)
+    }
+
+    (aucPairs, hitPairs)
+  }
+
+  /**
+   * Computes a one-vs-rest AUC for each of the first `numCols` label columns.
+   *
+   * Every row holds one `(isPositive, score)` pair per class. For each column the pairs are
+   * sorted by ascending score and the AUC is derived from the rank sum of the positives
+   * (Mann-Whitney U statistic); tied scores share their average rank.
+   *
+   * Fixes over the previous version:
+   *  - rows were sorted in descending order, which made the formula yield `1 - AUC`;
+   *  - the "no positives / no negatives" guard was a discarded expression, so such columns
+   *    produced NaN instead of 0.5;
+   *  - ties were ranked arbitrarily.
+   *
+   * @param aucData rows of `(isPositive, score)` pairs; every row must have at least
+   *                `numCols` entries
+   * @param numCols number of label columns to evaluate (defaults to 8)
+   * @return the AUC per column; 0.5 for a column without positives or without negatives
+   */
+  def computeAUC(aucData: RDD[Array[(Int, Double)]], numCols: Int = 8): Array[Double] = {
+    (0 until numCols).map { colIndex =>
+      // Ascending order is required by the rank-sum formula (rank 1 = lowest score).
+      // NOTE: the whole column is collected, so it has to fit in driver memory.
+      val sortedAsc = aucData
+        .map(row => row(colIndex))
+        .sortBy(_._2, ascending = true)
+        .collect()
+      aucFromAscending(sortedAsc)
+    }.toArray
+  }
+
+  /** Rank-sum AUC of `(isPositive, score)` pairs that are already sorted by ascending score. */
+  private def aucFromAscending(sorted: Array[(Int, Double)]): Double = {
+    val positives = sorted.count(_._1 == 1).toDouble
+    val negatives = sorted.length - positives
+    if (positives == 0 || negatives == 0) {
+      // AUC is undefined without both classes; report the chance level instead of NaN.
+      0.5
+    } else {
+      var rankSum = 0.0
+      var start = 0
+      while (start < sorted.length) {
+        // [start, end) is a run of equal scores. `end` begins at start + 1 so that the loop
+        // always advances, even for NaN scores (which never compare equal).
+        var end = start + 1
+        while (end < sorted.length && sorted(end)._2 == sorted(start)._2) end += 1
+        // 1-based ranks start + 1 .. end are shared, so each tied element gets their average.
+        val averageRank = (start + 1 + end) / 2.0
+        val positivesInRun = (start until end).count(i => sorted(i)._1 == 1)
+        rankSum += averageRank * positivesInRun
+        start = end
+      }
+      (rankSum - positives * (positives + 1) / 2) / (positives * negatives)
+    }
+  }
+
+  /**
+   * Computes the global accuracy and the per-label accuracy in a single pass over the data.
+   *
+   * Every row holds one `(isLabel, isArgMax)` pair per class. A row is predicted correctly
+   * when the class flagged as label is also the one flagged as argmax.
+   *
+   *  - Global accuracy: correctly predicted rows / total rows. The previous version divided by
+   *    rows * columns, which capped the result at `1 / numCols`.
+   *  - Per-label accuracy: for each class, correctly predicted rows with that label / rows with
+   *    that label; 0.0 when the label never occurs.
+   *
+   * @param accData rows of `(isLabel, isArgMax)` pairs; every row must have at least
+   *                `numCols` entries
+   * @param numCols number of label columns to evaluate (defaults to 8)
+   * @return `(globalAccuracy, perLabelAccuracy)`; both are 0.0-valued for empty input
+   */
+  def computeAccuracyRate(
+      accData: RDD[Array[(Int, Int)]],
+      numCols: Int = 8
+  ): (Double, Array[Double]) = {
+    // Counter layout: [0, numCols) correct rows per label, [numCols, 2 * numCols) rows per
+    // label, last slot = total number of rows.
+    val rowSlot = 2 * numCols
+    val counts = accData.aggregate(Array.fill(rowSlot + 1)(0L))(
+      // Spark allows both functions to mutate and return their first argument.
+      (acc, row) => {
+        for (col <- 0 until numCols) {
+          val (label, maxFlag) = row(col)
+          if (label == 1) {
+            acc(numCols + col) += 1
+            if (maxFlag == 1) acc(col) += 1
+          }
+        }
+        acc(rowSlot) += 1
+        acc
+      },
+      (left, right) => {
+        for (i <- left.indices) left(i) += right(i)
+        left
+      }
+    )
+
+    val hits = counts.slice(0, numCols)
+    val positives = counts.slice(numCols, rowSlot)
+    val totalRows = counts(rowSlot)
+
+    val globalAccuracy = if (totalRows == 0) 0.0 else hits.sum.toDouble / totalRows
+    val perLabelAccuracy = hits.zip(positives).map { case (hit, pos) =>
+      if (pos == 0) 0.0 else hit.toDouble / pos
+    }
+
+    (globalAccuracy, perLabelAccuracy)
+  }
+}