Browse Source

feat:修改XGB训练脚本 (modify the XGB training script)

zhaohaipeng committed 2 months ago
parent commit 83e6cccbf1

+ 15 - 12
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_xgb_train.scala

@@ -1,5 +1,6 @@
 package com.tzld.piaoquan.recommend.model
 
+import com.alibaba.fastjson.JSON
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
 import org.apache.commons.lang.math.NumberUtils
 import org.apache.commons.lang3.StringUtils
@@ -7,12 +8,10 @@ import org.apache.hadoop.io.compress.GzipCodec
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.ml.feature.VectorAssembler
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.types.{DataTypes, StructField}
-import org.apache.spark.sql.{Dataset, Row, RowFactory, SparkSession}
+import org.apache.spark.sql.types.DataTypes
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
 
-import scala.collection.JavaConversions._
 import java.util
-import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 
 object recsys_01_xgb_train {
@@ -68,6 +67,9 @@ object recsys_01_xgb_train {
 
     ) ++ features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
 
+    fields = fields ++ Array(
+      DataTypes.createStructField("logKey", DataTypes.StringType, true)
+    )
     val schema = DataTypes.createStructType(fields)
     val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
     val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
@@ -101,13 +103,13 @@ object recsys_01_xgb_train {
       features
     )
     val testDataSet = spark.createDataFrame(testData, schema)
-    val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label")
+    val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
     val predictions = model.transform(testDataSetTrans)
     //     [label, features, probability, prediction, rawPrediction]
     println("zhangbo:columns:" + predictions.columns.mkString(","))
-    val saveData = predictions.select("label", "rawPrediction", "probability").rdd
+    val saveData = predictions.select("label", "rawPrediction", "probability", "logKey").rdd
       .map(r => {
-        (r.get(0), r.get(1), r.get(2)).productIterator.mkString("\t")
+        (r.get(0), r.get(1), r.get(2), r.get(3)).productIterator.mkString("\t")
       })
     val hdfsPath = savePath
     if (hdfsPath.nonEmpty && hdfsPath.startsWith("/dw/recommend/model/")) {
@@ -129,16 +131,16 @@ object recsys_01_xgb_train {
     // 统计分cid的分数
     sc.textFile(hdfsPath).map(r => {
       val rList = r.split("\t")
-      val cid = rList(3)
+      val vid = JSON.parseObject(rList(3)).getString("vid")
       val score = rList(2).replace("[", "").replace("]", "")
         .split(",")(1).toDouble
       val label = rList(0).toDouble
-      (cid, (1, label, score))
+      (vid, (1, label, score))
     }).reduceByKey {
       case (a, b) => (a._1 + b._1, a._2 + b._2, a._3 + b._3)
     }.map {
-      case (cid, (all, zheng, scores)) =>
-        (cid, all, zheng, scores, zheng / all, scores / all)
+      case (vid, (all, zheng, scores)) =>
+        (vid, all, zheng, scores, zheng / all, scores / all)
     }.collect().sortBy(_._1).map(_.productIterator.mkString("\t")).foreach(println)
 
   }
@@ -155,11 +157,12 @@ object recsys_01_xgb_train {
         map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
       }
 
-      val v: Array[Any] = new Array[Any](features.length + 1)
+      val v: Array[Any] = new Array[Any](features.length + 2)
       v(0) = label
       for (i <- 0 until features.length) {
         v(i + 1) = map.getOrDefault(features(i), 0.0d)
       }
+      v(features.length + 1) = line(0)
       Row(v: _*)
     })
   }