Browse Source

feat:添加ros特征文件

zhaohaipeng 1 month ago
parent
commit
462933e9fd

+ 32 - 25
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_ros_multi_class_xgb_train.scala

@@ -7,9 +7,10 @@ import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
 import org.apache.commons.lang.math.NumberUtils
 import org.apache.commons.lang3.StringUtils
 import org.apache.hadoop.io.compress.GzipCodec
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DataTypes
 import org.apache.spark.sql.{Dataset, Row, SparkSession}
 
@@ -40,8 +41,8 @@ object recsys_01_ros_multi_class_xgb_train {
     val max_depth = param.getOrElse("max_depth", "5").toInt
     val num_round = param.getOrElse("num_round", "100").toInt
     val num_worker = param.getOrElse("num_worker", "20").toInt
-    val func_object = param.getOrElse("func_object", "multi:softmax")
-    val func_metric = param.getOrElse("func_metric", "auc")
+    val func_object = param.getOrElse("func_object", "multi:softprob")
+    val func_metric = param.getOrElse("func_metric", "mlogloss")
     val repartition = param.getOrElse("repartition", "20").toInt
     val numClass = param.getOrElse("numClass", "8").toInt
     val subsample = param.getOrElse("subsample", "0.95").toDouble
@@ -132,9 +133,9 @@ object recsys_01_ros_multi_class_xgb_train {
     val predictions = model.transform(testDataSetTrans)
     //     [label, features, probability, prediction, rawPrediction]
     println("zhangbo:columns:" + predictions.columns.mkString(","))
-    val saveData = predictions.select("label", "rawPrediction", "logKey").rdd
+    val saveData = predictions.select("label", "rawPrediction", "probability", "logKey").rdd
       .map(r => {
-        (r.get(0), r.get(1), r.get(2)).productIterator.mkString("\t")
+        (r.get(0), r.get(1), r.get(2), r.get(3)).productIterator.mkString("\t")
       })
     val hdfsPath = savePath
     if (hdfsPath.nonEmpty && hdfsPath.startsWith("/dw/recommend/model/")) {
@@ -146,28 +147,34 @@ object recsys_01_ros_multi_class_xgb_train {
     }
 
 
-//    val evaluator = new BinaryClassificationEvaluator()
-//      .setLabelCol("label")
-//      .setRawPredictionCol("probability")
-//      .setMetricName("areaUnderROC")
-//    val auc = evaluator.evaluate(predictions.select("label", "probability"))
-//    println("zhangbo:auc:" + auc)
+    // Compute the multiclass log loss. NOTE(review): the `$` column syntax and `.as[Double]` below need `import spark.implicits._` — not visible in this diff's import hunk; confirm it exists elsewhere in the file.
+    val logLossDF = predictions.withColumn("log_loss",
+      udf((label: Double, probVec: Vector) => {
+        val prob = probVec(label.toInt)  // probability the model assigned to the true class
+        -math.log(math.max(prob, 1e-15))  // clamp to avoid log(0)
+      }).apply($"label", $"probability")
+    )
+
+    // Average the per-row log loss over the test set
+    val mlogloss = logLossDF.agg(avg($"log_loss")).as[Double].collect()(0)
+    println(s"Multiclass Log Loss: $mlogloss")
+
 
     // Per-video score stats. NOTE(review): original comment said "cid", but the code keys on vid — confirm the intended grouping.
-//    sc.textFile(hdfsPath).map(r => {
-//      val rList = r.split("\t")
-//      val vid = JSON.parseObject(rList(3)).getString("vid")
-//      val label = rList(0).toDouble
-//      val score = RosUtil.multiClassModelScore(rList(2), predictLabelList_br.value)
-//
-//      ((vid, label), (1, score))
-//    }).reduceByKey {
-//      case ((c1, s1), (c2, s2)) =>
-//        (c1 + c2, (s1 + s2))
-//    }.map {
-//      case ((vid, label), (count, sumScore)) =>
-//        (vid, label, count, sumScore, sumScore / count)
-//    }.collect().sortBy(_._1).map(_.productIterator.mkString("\t")).foreach(println)
+    sc.textFile(hdfsPath).map(r => {
+      val rList = r.split("\t")
+      val vid = JSON.parseObject(rList(3)).getString("vid")
+      val label = rList(0).toDouble
+      val score = RosUtil.multiClassModelScore(rList(2), predictLabelList_br.value)
+
+      ((vid, label), (1, score))
+    }).reduceByKey {
+      case ((c1, s1), (c2, s2)) =>
+        (c1 + c2, (s1 + s2))
+    }.map {
+      case ((vid, label), (count, sumScore)) =>
+        (vid, label, count, sumScore, sumScore / count)
+    }.collect().sortBy(_._1).map(_.productIterator.mkString("\t")).foreach(println)
 
   }