Sfoglia il codice sorgente

rov训练增加负采样

jch 4 mesi fa
parent
commit
74c3f3f689

+ 26 - 17
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/train_recsys_61_xgb_rov_20241209.scala

@@ -12,6 +12,7 @@ import org.apache.spark.sql.{Dataset, Row, SparkSession}
 
 import java.util
 import scala.io.Source
+import scala.util.Random
 
 object train_recsys_61_xgb_rov_20241209 {
   def main(args: Array[String]): Unit = {
@@ -35,6 +36,7 @@ object train_recsys_61_xgb_rov_20241209 {
     val func_object = param.getOrElse("func_object", "binary:logistic")
     val func_metric = param.getOrElse("func_metric", "auc")
     val repartition = param.getOrElse("repartition", "20").toInt
+    val negRate = param.getOrElse("negRate", "1.0").toDouble
     val modelPath = param.getOrElse("modelPath", "/dw/recommend/model/61_recsys_rov_model/model_xgb")
     val modelFile = param.getOrElse("modelFile", "model_xgb_for_recsys_rov.tar.gz")
 
@@ -56,6 +58,7 @@ object train_recsys_61_xgb_rov_20241209 {
     println("features.size=" + features.length)
 
     val trainData = createData(
+      negRate,
       sc.textFile(trainPath),
       features
     )
@@ -100,6 +103,7 @@ object train_recsys_61_xgb_rov_20241209 {
 
     if (testPath.nonEmpty) {
       val testData = createData(
+        1.0,
         sc.textFile(testPath),
         features
       )
@@ -130,23 +134,28 @@ object train_recsys_61_xgb_rov_20241209 {
     }
   }
 
-  def createData(data: RDD[String], features: Array[String]): RDD[Row] = {
-    data.map(r => {
-      val line: Array[String] = StringUtils.split(r, '\t')
-      // val logKey = line(0)
-      val label: Int = NumberUtils.toInt(line(1))
-      val map: util.Map[String, Double] = new util.HashMap[String, Double]
-      for (i <- 2 until line.length) {
-        val fv: Array[String] = StringUtils.split(line(i), ':')
-        map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
-      }
+  def createData(negRate: Double, data: RDD[String], features: Array[String]): RDD[Row] = {
+    data.filter(r => {
+        val line: Array[String] = StringUtils.split(r, '\t')
+        val label: Int = NumberUtils.toInt(line(1))
+        label > 0 || new Random().nextDouble() <= negRate
+      })
+      .map(r => {
+        val line: Array[String] = StringUtils.split(r, '\t')
+        // val logKey = line(0)
+        val label: Int = NumberUtils.toInt(line(1))
+        val map: util.Map[String, Double] = new util.HashMap[String, Double]
+        for (i <- 2 until line.length) {
+          val fv: Array[String] = StringUtils.split(line(i), ':')
+          map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
+        }
 
-      val v: Array[Any] = new Array[Any](features.length + 1)
-      v(0) = label
-      for (i <- 0 until features.length) {
-        v(i + 1) = map.getOrDefault(features(i), 0.0d)
-      }
-      Row(v: _*)
-    })
+        val v: Array[Any] = new Array[Any](features.length + 1)
+        v(0) = label
+        for (i <- 0 until features.length) {
+          v(i + 1) = map.getOrDefault(features(i), 0.0d)
+        }
+        Row(v: _*)
+      })
   }
 }