Pārlūkot izejas kodu

Update makedata_ad_33_bucketDataToHive_20250110: support negative sampling

StrayWarrior 2 mēneši atpakaļ
vecāks
revīzija
bc6dfb1047

+ 6 - 3
src/main/scala/com/aliyun/odps/spark/examples/makedata_ad/v20240718/makedata_ad_33_bucketDataToHive_20250110.scala

@@ -9,6 +9,7 @@ import org.apache.spark.sql.SparkSession
 
 import scala.collection.JavaConversions._
 import scala.io.Source
+import scala.util.Random
 
 /*
 
@@ -56,8 +57,9 @@ object makedata_ad_33_bucketDataToHive_20250110 {
     val filterNames = param.getOrElse("filterNames", "").split(",").filter(_.nonEmpty).toSet
     val whatLabel = param.getOrElse("whatLabel", "ad_is_conversion")
     val project = param.getOrElse("project", "loghubods")
-    val table = param.getOrElse("table", "ad_easyrec_train_data_v2")
+    val table = param.getOrElse("table", "ad_easyrec_train_data_v2_sampled")
     val partition = param.getOrElse("partition", "dt=20250208")
+    val negSampleRate = param.getOrElse("negSampleRate", "1").toDouble
 
     val dateRange = MyDateUtils.getDateRange(beginStr, endStr)
     for (date <- dateRange) {
@@ -100,8 +102,9 @@ object makedata_ad_33_bucketDataToHive_20250110 {
             }.toMap
             resultMap += ("has_conversion" -> label)
             resultMap += ("logkey" -> logKey)
-            resultMap
-        }
+            (label.toInt, resultMap, Random.nextDouble())
+        }.filter(r => r._3 < negSampleRate || r._1 > 0)
+        .map(r => r._2)
 
       // 4 hive
       odpsOps.saveToTable(project, table, partition, list, write, defaultCreate = true, overwrite = true)