|
@@ -0,0 +1,104 @@
|
|
|
+package com.tzld.piaoquan.recommend.model
|
|
|
+
|
|
|
+import com.tzld.piaoquan.recommend.utils.ParamUtils
|
|
|
+import org.apache.spark.rdd.RDD
|
|
|
+import org.apache.spark.sql.SparkSession
|
|
|
+
|
|
|
+import java.time.LocalDateTime
|
|
|
+import java.time.format.DateTimeFormatter
|
|
|
+
|
|
|
+object recsys_02_ros_multi_class_predict_analyse {
|
|
|
+ def main(args: Array[String]): Unit = {
|
|
|
+ val dt = DateTimeFormatter.ofPattern("yyyyMMddHHmm").format(LocalDateTime.now())
|
|
|
+
|
|
|
+ val spark = SparkSession
|
|
|
+ .builder()
|
|
|
+ .appName(this.getClass.getName + ": " + dt)
|
|
|
+ .getOrCreate()
|
|
|
+ val sc = spark.sparkContext
|
|
|
+
|
|
|
+ val param = ParamUtils.parseArgs(args)
|
|
|
+ val readPath = param.getOrElse("readPath", "/dw/recommend/model/44_recsys_ros_predict/")
|
|
|
+
|
|
|
+ val data = sc.textFile(readPath)
|
|
|
+ .map(parseLine)
|
|
|
+ .cache()
|
|
|
+
|
|
|
+ // 计算 AUC
|
|
|
+ val aucScores = computeAUC(data.map(_._1))
|
|
|
+
|
|
|
+ // 计算 accuracyRate 指标
|
|
|
+ val (globalAcc, perLabelAcc) = computeAccuracyRate(data.map(_._2))
|
|
|
+
|
|
|
+ // 打印结果
|
|
|
+ println("AUC Scores:")
|
|
|
+ aucScores.zipWithIndex.foreach { case (auc, index) => println(s"Label $index: AUC = $auc") }
|
|
|
+
|
|
|
+ println(s"\nGlobal Accuracy: $globalAcc")
|
|
|
+
|
|
|
+ println("\nPer Label Accuracy:")
|
|
|
+ perLabelAcc.zipWithIndex.foreach { case (acc, index) =>
|
|
|
+ println(s"Label $index: Accuracy = $acc")
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ def parseLine(line: String): (Array[(Int, Double)], Array[(Int, Int)]) = {
|
|
|
+ val sLine = line.split("\t")
|
|
|
+ val label = sLine(0).toInt
|
|
|
+ val scores = sLine(2).replace("[", "").replace("]", "").trim.split(",").map(_.toDouble)
|
|
|
+
|
|
|
+ // 找到最大值索引
|
|
|
+ val maxIndex = scores.zipWithIndex.maxBy(_._1)._2
|
|
|
+
|
|
|
+ // 生成八个数组
|
|
|
+ val aucs = scores.indices.map(i => (if (i == label) 1 else 0, scores(i))).toArray
|
|
|
+
|
|
|
+ // 新数组 (index, 是否是最大值)
|
|
|
+ val accuracyRate = scores.indices.map(i => (if (i == label) 1 else 0, if (i == maxIndex) 1 else 0)).toArray
|
|
|
+
|
|
|
+ (aucs, accuracyRate)
|
|
|
+ }
|
|
|
+
|
|
|
+ /** 计算每一列的 AUC */
|
|
|
+ def computeAUC(aucData: RDD[Array[(Int, Double)]]): Array[Double] = {
|
|
|
+ val numCols = 8
|
|
|
+ (0 until numCols).map { colIndex =>
|
|
|
+ val colData = aucData.map(arr => arr(colIndex)) // 取第 colIndex 列
|
|
|
+ val sortedData = colData.sortBy(_._2, ascending = false).collect()
|
|
|
+
|
|
|
+ val positive = sortedData.count(_._1 == 1).toDouble
|
|
|
+ val negative = sortedData.length - positive
|
|
|
+ if (positive == 0 || negative == 0) 0.5 // 避免无正样本或负样本导致 AUC 计算错误
|
|
|
+
|
|
|
+ var auc = 0.0
|
|
|
+ var rankSum = 0.0
|
|
|
+ var count = 0.0
|
|
|
+ sortedData.zipWithIndex.foreach { case ((label, _), rank) =>
|
|
|
+ if (label == 1) rankSum += rank + 1
|
|
|
+ }
|
|
|
+ auc = (rankSum - positive * (positive + 1) / 2) / (positive * negative)
|
|
|
+ auc
|
|
|
+ }.toArray
|
|
|
+ }
|
|
|
+
|
|
|
+ /** 计算全局 accuracy 和按 label 计算 accuracy */
|
|
|
+ def computeAccuracyRate(accData: RDD[Array[(Int, Int)]]): (Double, Array[Double]) = {
|
|
|
+ val numCols = 8
|
|
|
+
|
|
|
+ // 计算全局 accuracy (总计所有列)
|
|
|
+ val globalCorrect = accData.flatMap(_.filter { case (label, maxFlag) => label == 1 && maxFlag == 1 }).count().toDouble
|
|
|
+ val totalCount = accData.flatMap(_.map(_ => 1)).count().toDouble
|
|
|
+ val globalAccuracy = if (totalCount == 0) 0.0 else globalCorrect / totalCount
|
|
|
+
|
|
|
+ // 计算按 label 的 accuracy
|
|
|
+ val perLabelAccuracy = (0 until numCols).map { colIndex =>
|
|
|
+ val colData = accData.map(arr => arr(colIndex))
|
|
|
+ val truePositive = colData.filter { case (label, maxFlag) => label == 1 && maxFlag == 1 }.count().toDouble
|
|
|
+ val totalPositive = colData.filter(_._1 == 1).count().toDouble
|
|
|
+ if (totalPositive == 0) 0.0 else truePositive / totalPositive
|
|
|
+ }.toArray
|
|
|
+
|
|
|
+ (globalAccuracy, perLabelAccuracy)
|
|
|
+ }
|
|
|
+}
|