often 5 months ago
parent
commit
f8b0d7e904

+ 51 - 49
src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

@@ -129,66 +129,72 @@ object video_dssm_sampler {
         partition = s"dt=20241124,hh=08",
         transfer = funcPositive,
         numPartition = CONFIG("shuffle.partitions").toInt
-      ).sample(false, 0.001) // randomly sample 0.1% of the data
-       .persist(StorageLevel.MEMORY_AND_DISK_SER)
+      ).persist(StorageLevel.MEMORY_AND_DISK_SER)
       println("开始执行partiton:" + partition)
 
       val positivePairs = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK_SER)
       stats.positiveSamplesCount = positivePairs.count()
       logger.info(s"start read vid list for date $dt")
 
-      val numberedPos = positivePairs
-        .withColumn("pos_id", row_number().over(Window.orderBy("vid_left", "vid_right")))
-      // 2. number the candidate video table in groups of 100
-      val vidsRDD = odpsOps.readTable(
-        project = "loghubods",
-        table = "t_vid_tag_feature",
-        partition = s"dt='$dt'",
-        transfer = funcVids,
-        numPartition = CONFIG("shuffle.partitions").toInt
-      )
+      // 2. fetch the full list of available vids
+      val allVids = time({
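+        // NOTE: collect() pulls the complete vid list onto the driver (and into
+        // each executor via the broadcast below); this assumes the vid
+        // vocabulary is small enough to fit in memory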
+        odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_tag_feature",
+          partition = s"dt='$dt'",
+          transfer = funcVids,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        ).collect().toSet
+      }, "Loading all vids")
+
+      // 3. broadcast the vid set to every executor
+      val allVidsBroadcast = spark.sparkContext.broadcast(allVids)
+
+      // 4. generate negative samples
+      val negativeSamplesRDD = time({
+        positivePairs.rdd
+          .mapPartitions { iter =>
+            // process rows in batches instead of one at a time
+            val batchSize = 100
+            val localVids = allVidsBroadcast.value.toArray // convert the broadcast set to an array only once
+            val random = new Random()
+
+            iter.grouped(batchSize).flatMap { batch =>
+              // build the negative-sample pool once for the entire batch
+              val excludeVids = batch.flatMap(row =>
+                Seq(row.getAs[String]("vid_left"), row.getAs[String]("vid_right"))
+              ).toSet
+
+              // pre-filter the candidate pool once
+              val candidatePool = localVids.filter(vid => !excludeVids.contains(vid))
+
+              batch.flatMap { row =>
+                val vid_left = row.getAs[String]("vid_left")
+                // sample quickly from the pre-filtered pool: shuffle its first
+                // 1000 entries and keep 20 (.toSeq so Random.shuffle accepts it)
+                random.shuffle(candidatePool.take(1000).toSeq)
+                  .take(20)
+                  .map(negative_vid => Row(vid_left, negative_vid, null, null, null, null, null, null, null))
+              }
+            }
+          }
+      }, "Generating negative samples")
+      // convert the RDD back to a DataFrame
+      val negativeSamplesDF = spark.createDataFrame(negativeSamplesRDD, schema)
+      stats.negativeSamplesCount = negativeSamplesRDD.count()
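
(For reference, the batching trick above can be distilled into a standalone
sketch. The names below are illustrative, not this job's real API; it assumes a
positives DataFrame with string columns vid_left / vid_right and a driver-side
vid set. Note the caveat in the comments: shuffling candidates.take(1000)
permutes only the first 1000 survivors of the filter, so the draw is not
uniform over the whole pool.)

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import scala.util.Random

def sampleNegatives(positives: DataFrame,
                    vidsBc: Broadcast[Set[String]],
                    perRow: Int = 20): RDD[(String, String)] =
  positives.rdd.mapPartitions { iter =>
    val pool = vidsBc.value.toArray      // materialize the broadcast set once per partition
    val random = new Random()
    iter.grouped(100).flatMap { batch => // amortize the exclusion set over 100 rows
      val exclude = batch.flatMap(r =>
        Seq(r.getAs[String]("vid_left"), r.getAs[String]("vid_right"))).toSet
      val candidates = pool.filterNot(exclude.contains)
      batch.flatMap { r =>
        // shuffle only the first 1000 filtered candidates, then keep perRow of them
        random.shuffle(candidates.take(1000).toSeq)
          .take(perRow)
          .map(neg => (r.getAs[String]("vid_left"), neg))
      }
    }
  }
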
 
-      // convert the RDD to a DataFrame
-      val numberedVid = spark.createDataFrame(vidsRDD.map(vid => Row(vid)),
-          StructType(Seq(StructField("vid", StringType, true))))
-        .withColumn("group_id", floor((row_number().over(Window.orderBy(rand())) - 1) / 100))
-        .persist(StorageLevel.MEMORY_AND_DISK_SER)
-
-      val negativeSamples = numberedPos
-        .join(numberedVid,
-          floor((col("pos_id") - 1) / 100) === col("group_id"))
-        .where(col("vid") =!= col("vid_left") && col("vid") =!= col("vid_right"))
-        .select(
-          lit(null).as("extend"),
-          lit(null).as("view_24h"),
-          lit(null).as("total_return_uv"),
-          lit(null).as("ts_right"),
-          lit(null).as("apptype"),
-          lit(null).as("pagesource"),
-          lit(null).as("mid"),
-          col("vid_left"),
-          col("vid").as("vid_right")
-        )
-        .withColumn("rn", row_number().over(Window.partitionBy("vid_left").orderBy(rand())))
-      // 4. keep 2 negative samples per vid_left
-      val filteredNegative = negativeSamples
-        .where(col("rn") <= 2)
-        .drop("rn")
 
       // 5. merge positive and negative samples
       val allSamples = positivePairs
         .withColumn("label", lit(1))
         .withColumn("logid", concat(lit("pos_"), monotonically_increasing_id()))
         .union(
-          filteredNegative
+          negativeSamplesDF
             .withColumn("label", lit(0))
             .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id()))
-        )
-        .persist(StorageLevel.MEMORY_AND_DISK_SER)
+        ).persist(StorageLevel.MEMORY_AND_DISK_SER)
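
(A caveat on the union step: DataFrame.union aligns columns by position, not by
name, so negativeSamplesDF must reproduce the exact column order of
positivePairs. If the schemas share names but the ordering is uncertain,
unionByName, available since Spark 2.3, is the safer variant; posDF / negDF
below are hypothetical stand-ins.)

import org.apache.spark.sql.functions.lit

val merged = posDF.withColumn("label", lit(1))
  .unionByName(negDF.withColumn("label", lit(0)))
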
 
       // 6. fetch left-side features
       // read L1 category statistics features
-      /*
       val l1CatStatFeatures = {
         val rdd = odpsOps.readTable(
           project = "loghubods",
@@ -365,12 +371,11 @@ object video_dssm_sampler {
           col("vid_right_cate_l2_feature")
         )
         .persist(StorageLevel.MEMORY_AND_DISK_SER)
-      */
+
 
 
       // save the result to HDFS
-      //val resultWithDt = result.withColumn("dt", lit(s"$dt"))
-      val resultWithDt = allSamples.withColumn("dt", lit(s"$dt"))
+      val resultWithDt = result.withColumn("dt", lit(s"$dt"))
       resultWithDt.write
         .mode("overwrite")
         .partitionBy("dt")
@@ -380,9 +385,8 @@ object video_dssm_sampler {
 
       // 8. clear cached data
       positivePairs.unpersist()
-      negativeSamples.unpersist()
+
       allSamples.unpersist()
-      /*
       l1CatStatFeatures.unpersist()
       l2CatStatFeatures.unpersist()
       tagFeatures.unpersist()
@@ -394,8 +398,6 @@ object video_dssm_sampler {
       statRightFeatures.unpersist()
       categoryRightWithStats.unpersist()
       result.unpersist()
-      */
-
       // log summary statistics
       stats.logStats()
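
(The time(...) wrapper used by the new code is defined elsewhere in this file
and not shown in the hunk. A plausible minimal shape, assuming it evaluates the
block, logs the elapsed wall-clock time under the given label, and returns the
block's result:)

// hypothetical reconstruction; the real helper may differ
def time[T](block: => T, label: String): T = {
  val start = System.currentTimeMillis()
  val result = block
  println(s"$label took ${System.currentTimeMillis() - start} ms")
  result
}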