Browse Source

Update pred_01_xgb_ad_hdfsfile_20240813: support negative sample calibration

StrayWarrior 3 tháng trước cách đây
mục cha
commit
588af303e2

+ 15 - 3
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/pred_01_xgb_ad_hdfsfile_20240813.scala

@@ -10,6 +10,7 @@ import org.apache.spark.sql.types.DataTypes
 import org.apache.spark.sql.{Dataset, Row, SparkSession}
 import com.alibaba.fastjson.{JSON, JSONArray}
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.sql.functions.udf
 
 import java.util
 import scala.collection.mutable.ArrayBuffer
@@ -28,6 +29,7 @@ object pred_01_xgb_ad_hdfsfile_20240813{
     val testPath = param.getOrElse("testPath", "")
     val savePath = param.getOrElse("savePath", "/dw/recommend/model/34_ad_predict_data/")
     val featureFilter = param.getOrElse("featureFilter", "XXXXXX").split(",")
+    val negSampleRate = param.getOrElse("negSampleRate", "1").toDouble
 
     val repartition = param.getOrElse("repartition", "20").toInt
     val modelPath = param.getOrElse("modelPath", "/dw/recommend/model/35_ad_model/model_xgb")
@@ -62,16 +64,26 @@ object pred_01_xgb_ad_hdfsfile_20240813{
     val model = XGBoostClassificationModel.load(modelPath)
     model.setMissing(0.0f).setFeaturesCol("features")
 
-
-
     val testData = createData4Ad(
       sc.textFile(testPath),
       features
     )
 
+    def calibrateUDF = udf((probability: org.apache.spark.ml.linalg.Vector) => {
+      val positiveProb = probability.toArray(1)
+      val calibratedProb = positiveProb / (positiveProb + (1 - positiveProb) / negSampleRate)
+      val newProb = Array(1 - calibratedProb, calibratedProb)
+      org.apache.spark.ml.linalg.Vectors.dense(newProb)
+    })
+
     val testDataSet = spark.createDataFrame(testData, schema)
     val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
-    val predictions = model.transform(testDataSetTrans)
+    var predictions = model.transform(testDataSetTrans)
+    // calibrate the prediction for negative sampling
+    if (negSampleRate < 1) {
+      println("calibrate the prediction for negative sampling")
+      predictions = predictions.withColumn("probability", calibrateUDF(predictions("probability")))
+    }
 
     val saveData = predictions.select("label", "rawPrediction", "probability", "logKey").rdd
       .map(r => {