Explorar o código

Update train_01_xgb_ad_20250104: support negative sample calibration

StrayWarrior hai 3 meses
pai
achega
9835503261

+ 14 - 1
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/train_01_xgb_ad_20250104.scala

@@ -7,6 +7,7 @@ import org.apache.hadoop.io.compress.GzipCodec
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.ml.feature.VectorAssembler
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions.udf
 import org.apache.spark.sql.types.DataTypes
 import org.apache.spark.sql.{Dataset, Row, SparkSession}
 
@@ -104,6 +105,12 @@ object train_01_xgb_ad_20250104 {
       model = XGBoostClassificationModel.load(modelPath)
     }
 
+    def calibrateUDF = udf((probability: org.apache.spark.ml.linalg.Vector) => {
+      val positiveProb = probability.toArray(1)
+      val calibratedProb = positiveProb / (positiveProb + (1 - positiveProb) / negSampleRate)
+      val newProb = Array(1 - calibratedProb, calibratedProb)
+      org.apache.spark.ml.linalg.Vectors.dense(newProb)
+    })
 
     val testData = createData4Ad(
       sc.textFile(testPath),
@@ -111,8 +118,14 @@ object train_01_xgb_ad_20250104 {
     )
     val testDataSet = spark.createDataFrame(testData, schema)
     val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features","label", "logKey")
-    val predictions = model.transform(testDataSetTrans)
+    var predictions = model.transform(testDataSetTrans)
     println("columns:" + predictions.columns.mkString(","))
+    // calibrate the prediction for negative sampling
+    if (negSampleRate < 1) {
+      println("calibrate the prediction for negative sampling")
+      predictions = predictions.withColumn("probability", calibrateUDF(predictions("probability")))
+    }
+
     val saveData = predictions.select("label", "rawPrediction", "probability", "logKey").rdd
       .map(r =>{
         (r.get(0), r.get(1), r.get(2), r.get(3)).productIterator.mkString("\t")