
feat: develop ros binary classification evaluation script

zhaohaipeng 1 month ago
parent
current commit
026ce84cb5

+ 1 - 1
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_ros_binary_weight_xgb_train.scala

@@ -166,7 +166,7 @@ object recsys_01_ros_binary_weight_xgb_train {
       val logKey = line(0)
 
       val logJson = JSON.parseObject(logKey)
-      val weight = logJson.getDouble(weightField);
+      val weight = NumberUtils.toDouble(StringUtils.defaultIfBlank(logJson.getString(weightField), "0"));
 
       val label: Int = NumberUtils.toInt(line(1))
       val map: util.Map[String, Double] = new util.HashMap[String, Double]
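The one-line change above hardens weight parsing: fastjson's getDouble returns null when the key is absent, and unboxing that null into a Scala Double throws a NullPointerException for log rows that carry no weight field. A minimal sketch of the two behaviours (the literal "weight" key and the sample JSON are hypothetical; the real key comes from the weightField variable):

    // uses what the training script already imports: fastjson JSON, commons-lang NumberUtils, commons-lang3 StringUtils
    val logJson = JSON.parseObject("""{"logkey":"abc"}""") // no weight key present
    // before: logJson.getDouble("weight") yields null, which blows up when unboxed to Double
    // after: blank or missing values fall back to "0" and parse to 0.0
    val weight = NumberUtils.toDouble(StringUtils.defaultIfBlank(logJson.getString("weight"), "0"))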

+ 112 - 0
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_02_ros_model_predict.scala

@@ -0,0 +1,112 @@
+package com.tzld.piaoquan.recommend.model
+
+import com.tzld.piaoquan.recommend.utils.{FileUtils, MyHdfsUtils, ParamUtils}
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel
+import org.apache.commons.lang.math.NumberUtils
+import org.apache.commons.lang3.StringUtils
+import org.apache.hadoop.io.compress.GzipCodec
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.DataTypes
+import org.apache.spark.sql.{Row, SparkSession}
+
+import java.time.LocalDateTime
+import java.time.format.DateTimeFormatter
+import java.util
+
+object recsys_02_ros_model_predict {
+
+  def main(args: Array[String]): Unit = {
+    val dt = DateTimeFormatter.ofPattern("yyyyMMddHHmm").format(LocalDateTime.now())
+
+    val spark = SparkSession
+      .builder()
+      .appName(this.getClass.getName + ": " + dt)
+      .getOrCreate()
+    val sc = spark.sparkContext
+
+    val param = ParamUtils.parseArgs(args)
+    val featureFile = param.getOrElse("featureFile", "20240703_ad_feature_name.txt")
+    val testPath = param.getOrElse("testPath", "")
+    val savePath = param.getOrElse("savePath", "/dw/recommend/model/34_ad_predict_data/")
+    val featureFilter = param.getOrElse("featureFilter", "").split(",").filter(_.nonEmpty).toList
+
+    val repartition = param.getOrElse("repartition", "20").toInt
+    val modelPath = param.getOrElse("modelPath", "/dw/recommend/model/35_ad_model/model_xgb")
+
+    val loader = getClass.getClassLoader
+    val resourceUrl = loader.getResource(featureFile)
+    val fileContent = FileUtils.readFile(resourceUrl)
+    println(fileContent)
+
+    val features = fileContent.split("\n")
+      .map(r => r.replace(" ", "").replaceAll("\n", ""))
+      .filter(r => r.nonEmpty && !featureFilter.contains(r))
+    println("features.size=" + features.length)
+
+    var fields = Array(
+      DataTypes.createStructField("label", DataTypes.IntegerType, true)
+    ) ++ features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
+
+    fields = fields ++ Array(
+      DataTypes.createStructField("logKey", DataTypes.StringType, true)
+    )
+
+    val schema = DataTypes.createStructType(fields)
+    val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
+
+    val testData = createData4Ad(sc.textFile(testPath), features)
+
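+    // Load the trained model; setMissing(0.0f) must match the missing-value convention used at training time.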
+    val model = XGBoostClassificationModel.load(modelPath)
+    model.setMissing(0.0f).setFeaturesCol("features")
+
+    val testDataSet = spark.createDataFrame(testData, schema)
+    val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
+    val predictions = model.transform(testDataSetTrans)
+
+    val saveData = predictions.select("label", "rawPrediction", "probability", "logKey").rdd
+      .map(r => {
+        (r.get(0), r.get(1), r.get(2), r.get(3)).productIterator.mkString("\t")
+      })
+    val hdfsPath = savePath
+    if (hdfsPath.nonEmpty && hdfsPath.startsWith("/dw/recommend/model/")) {
+      println("删除路径并开始数据写入:" + hdfsPath)
+      MyHdfsUtils.delete_hdfs_path(hdfsPath)
+      saveData.repartition(repartition).saveAsTextFile(hdfsPath, classOf[GzipCodec])
+    } else {
+      println("路径不合法,无法写入:" + hdfsPath)
+    }
+
+
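+    // AUC is rank-based, so scoring on the probability column (monotone in the raw margin) gives the same result.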
+    val evaluator = new BinaryClassificationEvaluator()
+      .setLabelCol("label")
+      .setRawPredictionCol("probability")
+      .setMetricName("areaUnderROC")
+    val auc = evaluator.evaluate(predictions.select("label", "probability"))
+    println("zhangbo:auc:" + auc)
+
+  }
+
+  def createData4Ad(data: RDD[String], features: Array[String]): RDD[Row] = {
+    data.map(r => {
+      val line: Array[String] = StringUtils.split(r, '\t')
+      val logKey = line(0)
+
+      val label: Int = NumberUtils.toInt(line(1))
+      val map: util.Map[String, Double] = new util.HashMap[String, Double]
+      for (i <- 2 until line.length) {
+        val fv: Array[String] = StringUtils.split(line(i), ':')
+        map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
+      }
+
+      val v: Array[Any] = new Array[Any](features.length + 2)
+      v(0) = label
+      for (i <- features.indices) {
+        v(i + 1) = map.getOrDefault(features(i), 0.0d)
+      }
+      v(features.length + 1) = logKey
+      Row(v: _*)
+    })
+  }
+}
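For reference, createData4Ad consumes the same TSV layout the training job reads: logKey, label, then sparse name:value feature pairs, with absent features defaulting to 0.0 in the output Row. A minimal local sketch (the feature names and sample line are hypothetical):

    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    val features = Array("ctr", "ros")
    val lines = spark.sparkContext.parallelize(Seq("someLogKey\t1\tctr:0.12\tros:0.05"))
    // prints [1,0.12,0.05,someLogKey]: label, feature columns in declared order, logKey last
    recsys_02_ros_model_predict.createData4Ad(lines, features).collect().foreach(println)
    spark.stop()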

+ 0 - 104
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_02_ros_multi_class_predict_analyse.scala

@@ -1,104 +0,0 @@
-package com.tzld.piaoquan.recommend.model
-
-import com.tzld.piaoquan.recommend.utils.ParamUtils
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
-
-import java.time.LocalDateTime
-import java.time.format.DateTimeFormatter
-
-object recsys_02_ros_multi_class_predict_analyse {
-  def main(args: Array[String]): Unit = {
-    val dt = DateTimeFormatter.ofPattern("yyyyMMddHHmm").format(LocalDateTime.now())
-
-    val spark = SparkSession
-      .builder()
-      .appName(this.getClass.getName + ": " + dt)
-      .getOrCreate()
-    val sc = spark.sparkContext
-
-    val param = ParamUtils.parseArgs(args)
-    val readPath = param.getOrElse("readPath", "/dw/recommend/model/44_recsys_ros_predict/")
-
-    val data = sc.textFile(readPath)
-      .map(parseLine)
-      .cache()
-
-    // Compute AUC scores
-    val aucScores = computeAUC(data.map(_._1))
-
-    // Compute the accuracyRate metric
-    val (globalAcc, perLabelAcc) = computeAccuracyRate(data.map(_._2))
-
-    // Print results
-    println("AUC Scores:")
-    aucScores.zipWithIndex.foreach { case (auc, index) => println(s"Label $index: AUC = $auc") }
-
-    println(s"\nGlobal Accuracy: $globalAcc")
-
-    println("\nPer Label Accuracy:")
-    perLabelAcc.zipWithIndex.foreach { case (acc, index) =>
-      println(s"Label $index: Accuracy = $acc")
-    }
-
-  }
-
-  def parseLine(line: String): (Array[(Int, Double)], Array[(Int, Int)]) = {
-    val sLine = line.split("\t")
-    val label = sLine(0).toInt
-    val scores = sLine(2).replace("[", "").replace("]", "").trim.split(",").map(_.toDouble)
-
-    // Find the index of the highest score
-    val maxIndex = scores.zipWithIndex.maxBy(_._1)._2
-
-    // Build the eight (isLabel, score) pairs for per-class AUC
-    val aucs = scores.indices.map(i => (if (i == label) 1 else 0, scores(i))).toArray
-
-    // Paired array: (isLabel, isPredictedMax)
-    val accuracyRate = scores.indices.map(i => (if (i == label) 1 else 0, if (i == maxIndex) 1 else 0)).toArray
-
-    (aucs, accuracyRate)
-  }
-
-  /** Compute the AUC of each column */
-  def computeAUC(aucData: RDD[Array[(Int, Double)]]): Array[Double] = {
-    val numCols = 8
-    (0 until numCols).map { colIndex =>
-      val colData = aucData.map(arr => arr(colIndex)) // take column colIndex
-      val sortedData = colData.sortBy(_._2, ascending = false).collect()
-
-      val positive = sortedData.count(_._1 == 1).toDouble
-      val negative = sortedData.length - positive
-      if (positive == 0 || negative == 0) 0.5 // avoid an ill-defined AUC when there are no positive or no negative samples
-
-      var auc = 0.0
-      var rankSum = 0.0
-      var count = 0.0
-      sortedData.zipWithIndex.foreach { case ((label, _), rank) =>
-        if (label == 1) rankSum += rank + 1
-      }
-      auc = (rankSum - positive * (positive + 1) / 2) / (positive * negative)
-      auc
-    }.toArray
-  }
-
-  /** Compute global accuracy and per-label accuracy */
-  def computeAccuracyRate(accData: RDD[Array[(Int, Int)]]): (Double, Array[Double]) = {
-    val numCols = 8
-
-    // Global accuracy (aggregated over all columns)
-    val globalCorrect = accData.flatMap(_.filter { case (label, maxFlag) => label == 1 && maxFlag == 1 }).count().toDouble
-    val totalCount = accData.flatMap(_.map(_ => 1)).count().toDouble
-    val globalAccuracy = if (totalCount == 0) 0.0 else globalCorrect / totalCount
-
-    // Per-label accuracy
-    val perLabelAccuracy = (0 until numCols).map { colIndex =>
-      val colData = accData.map(arr => arr(colIndex))
-      val truePositive = colData.filter { case (label, maxFlag) => label == 1 && maxFlag == 1 }.count().toDouble
-      val totalPositive = colData.filter(_._1 == 1).count().toDouble
-      if (totalPositive == 0) 0.0 else truePositive / totalPositive
-    }.toArray
-
-    (globalAccuracy, perLabelAccuracy)
-  }
-}

+ 21 - 0
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/utils/FileUtils.scala

@@ -0,0 +1,21 @@
+package com.tzld.piaoquan.recommend.utils
+
+import java.net.URL
+import scala.io.Source
+
+object FileUtils {
+  def readFile(path: URL): String = {
+    var source: Option[Source] = None
+    try {
+      source = Some(Source.fromURL(path))
+      val content = source.get.getLines().mkString("\n")
+      content
+    } catch {
+      case e: Exception =>
+        println(s"读取文件: ${path}, 发生未知错误: ${e.getMessage}")
+        ""
+    } finally {
+      source.foreach(_.close())
+    }
+  }
+}
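A typical call site, mirroring how recsys_02_ros_model_predict loads its feature list from the classpath:

    val url = getClass.getClassLoader.getResource("20240703_ad_feature_name.txt")
    val content = FileUtils.readFile(url) // returns "" (after logging the error) if the resource cannot be read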