@@ -7,9 +7,10 @@ import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
 import org.apache.commons.lang.math.NumberUtils
 import org.apache.commons.lang3.StringUtils
 import org.apache.hadoop.io.compress.GzipCodec
-import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DataTypes
 import org.apache.spark.sql.{Dataset, Row, SparkSession}
 
@@ -40,8 +41,8 @@ object recsys_01_ros_multi_class_xgb_train {
     val max_depth = param.getOrElse("max_depth", "5").toInt
     val num_round = param.getOrElse("num_round", "100").toInt
     val num_worker = param.getOrElse("num_worker", "20").toInt
-    val func_object = param.getOrElse("func_object", "multi:softmax")
-    val func_metric = param.getOrElse("func_metric", "auc")
+    val func_object = param.getOrElse("func_object", "multi:softprob")
+    val func_metric = param.getOrElse("func_metric", "mlogloss")
     val repartition = param.getOrElse("repartition", "20").toInt
     val numClass = param.getOrElse("numClass", "8").toInt
     val subsample = param.getOrElse("subsample", "0.95").toDouble
@@ -132,9 +133,9 @@ object recsys_01_ros_multi_class_xgb_train {
     val predictions = model.transform(testDataSetTrans)
     //     [label, features, probability, prediction, rawPrediction]
     println("zhangbo:columns:" + predictions.columns.mkString(","))
-    val saveData = predictions.select("label", "rawPrediction", "logKey").rdd
+    val saveData = predictions.select("label", "rawPrediction", "probability", "logKey").rdd
       .map(r => {
-        (r.get(0), r.get(1), r.get(2)).productIterator.mkString("\t")
+        (r.get(0), r.get(1), r.get(2), r.get(3)).productIterator.mkString("\t")
       })
     val hdfsPath = savePath
     if (hdfsPath.nonEmpty && hdfsPath.startsWith("/dw/recommend/model/")) {
@@ -146,28 +147,34 @@ object recsys_01_ros_multi_class_xgb_train {
     }
 
 
-//    val evaluator = new BinaryClassificationEvaluator()
-//      .setLabelCol("label")
-//      .setRawPredictionCol("probability")
-//      .setMetricName("areaUnderROC")
-//    val auc = evaluator.evaluate(predictions.select("label", "probability"))
-//    println("zhangbo:auc:" + auc)
+    // 计算 Multiclass Log Loss
+    val logLossDF = predictions.withColumn("log_loss",
+      udf((label: Double, probVec: Vector) => {
+        val prob = probVec(label.toInt)  // 获取正确类别的预测概率
+        -math.log(math.max(prob, 1e-15))  // 防止 log(0) 问题
+      }).apply($"label", $"probability")
+    )
+
+    // 计算平均 log loss
+    val mlogloss = logLossDF.agg(avg($"log_loss")).as[Double].collect()(0)
+    println(s"Multiclass Log Loss: $mlogloss")
+
 
     // 统计分cid的分数
-//    sc.textFile(hdfsPath).map(r => {
-//      val rList = r.split("\t")
-//      val vid = JSON.parseObject(rList(3)).getString("vid")
-//      val label = rList(0).toDouble
-//      val score = RosUtil.multiClassModelScore(rList(2), predictLabelList_br.value)
-//
-//      ((vid, label), (1, score))
-//    }).reduceByKey {
-//      case ((c1, s1), (c2, s2)) =>
-//        (c1 + c2, (s1 + s2))
-//    }.map {
-//      case ((vid, label), (count, sumScore)) =>
-//        (vid, label, count, sumScore, sumScore / count)
-//    }.collect().sortBy(_._1).map(_.productIterator.mkString("\t")).foreach(println)
+    sc.textFile(hdfsPath).map(r => {
+      val rList = r.split("\t")
+      val vid = JSON.parseObject(rList(3)).getString("vid")
+      val label = rList(0).toDouble
+      val score = RosUtil.multiClassModelScore(rList(2), predictLabelList_br.value)
+
+      ((vid, label), (1, score))
+    }).reduceByKey {
+      case ((c1, s1), (c2, s2)) =>
+        (c1 + c2, (s1 + s2))
+    }.map {
+      case ((vid, label), (count, sumScore)) =>
+        (vid, label, count, sumScore, sumScore / count)
+    }.collect().sortBy(_._1).map(_.productIterator.mkString("\t")).foreach(println)
 
   }
 