@@ -7,9 +7,10 @@ import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.commons.lang.math.NumberUtils
import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.io.compress.GzipCodec
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.{Dataset, Row, SparkSession}
@@ -40,8 +41,8 @@ object recsys_01_ros_multi_class_xgb_train {
    val max_depth = param.getOrElse("max_depth", "5").toInt
    val num_round = param.getOrElse("num_round", "100").toInt
    val num_worker = param.getOrElse("num_worker", "20").toInt
-    val func_object = param.getOrElse("func_object", "multi:softmax")
-    val func_metric = param.getOrElse("func_metric", "auc")
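+    // multi:softprob makes the booster emit the full per-class probability distribution
+    // rather than just the predicted class, which is what the mlogloss metric scores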
+    val func_object = param.getOrElse("func_object", "multi:softprob")
+    val func_metric = param.getOrElse("func_metric", "mlogloss")
    val repartition = param.getOrElse("repartition", "20").toInt
    val numClass = param.getOrElse("numClass", "8").toInt
    val subsample = param.getOrElse("subsample", "0.95").toDouble
@@ -132,9 +133,9 @@ object recsys_01_ros_multi_class_xgb_train {
    val predictions = model.transform(testDataSetTrans)
    // [label, features, probability, prediction, rawPrediction]
    println("zhangbo:columns:" + predictions.columns.mkString(","))
-    val saveData = predictions.select("label", "rawPrediction", "logKey").rdd
+    val saveData = predictions.select("label", "rawPrediction", "probability", "logKey").rdd
      .map(r => {
-      (r.get(0), r.get(1), r.get(2)).productIterator.mkString("\t")
+      (r.get(0), r.get(1), r.get(2), r.get(3)).productIterator.mkString("\t")
      })
    val hdfsPath = savePath
    if (hdfsPath.nonEmpty && hdfsPath.startsWith("/dw/recommend/model/")) {
@@ -146,28 +147,34 @@ object recsys_01_ros_multi_class_xgb_train {
    }

-//    val evaluator = new BinaryClassificationEvaluator()
-//      .setLabelCol("label")
-//      .setRawPredictionCol("probability")
-//      .setMetricName("areaUnderROC")
-//    val auc = evaluator.evaluate(predictions.select("label", "probability"))
-//    println("zhangbo:auc:" + auc)
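+    // Replaces the BinaryClassificationEvaluator AUC check above: areaUnderROC assumes a
+    // binary label, while this model predicts numClass classes, so use mlogloss = -(1/N)·Σ log p_i(y_i)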
+    // Compute the per-row multiclass log loss
+    val logLossDF = predictions.withColumn("log_loss",
+      udf((label: Double, probVec: Vector) => {
+        val prob = probVec(label.toInt) // predicted probability of the true class
+        -math.log(math.max(prob, 1e-15)) // clamp to avoid log(0)
+      }).apply($"label", $"probability")
+    )
+
+    // Average the log loss over all rows
+    val mlogloss = logLossDF.agg(avg($"log_loss")).as[Double].collect()(0)
+    println(s"Multiclass Log Loss: $mlogloss")
+

    // Score statistics broken down by cid
-//    sc.textFile(hdfsPath).map(r => {
-//      val rList = r.split("\t")
-//      val vid = JSON.parseObject(rList(3)).getString("vid")
-//      val label = rList(0).toDouble
-//      val score = RosUtil.multiClassModelScore(rList(2), predictLabelList_br.value)
-//
-//      ((vid, label), (1, score))
-//    }).reduceByKey {
-//      case ((c1, s1), (c2, s2)) =>
-//        (c1 + c2, (s1 + s2))
-//    }.map {
-//      case ((vid, label), (count, sumScore)) =>
-//        (vid, label, count, sumScore, sumScore / count)
-//    }.collect().sortBy(_._1).map(_.productIterator.mkString("\t")).foreach(println)
+    sc.textFile(hdfsPath).map(r => {
+      val rList = r.split("\t")
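+      // TSV columns written by saveData above: 0=label, 1=rawPrediction, 2=probability, 3=logKey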
+      val vid = JSON.parseObject(rList(3)).getString("vid")
+      val label = rList(0).toDouble
+      val score = RosUtil.multiClassModelScore(rList(2), predictLabelList_br.value)
+
+      ((vid, label), (1, score))
+    }).reduceByKey {
+      case ((c1, s1), (c2, s2)) =>
+        (c1 + c2, s1 + s2)
+    }.map {
+      case ((vid, label), (count, sumScore)) =>
+        (vid, label, count, sumScore, sumScore / count)
+    }.collect().sortBy(_._1).map(_.productIterator.mkString("\t")).foreach(println)

}