often hai 5 meses
pai
achega
2cf07a84fb

+ 38 - 8
src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

@@ -12,6 +12,13 @@ import scala.util.Random
 import scala.collection.mutable.ArrayBuffer
 import org.apache.log4j.{Level, Logger}
 
+import scala.util.Random
+import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.types.StructType
+
+
+
 object video_dssm_sampler {
   private val logger = Logger.getLogger(this.getClass)
 
@@ -154,26 +161,49 @@ object video_dssm_sampler {
       val negativeSamplesRDD = time({
         positivePairs.rdd
           .mapPartitions { iter =>
-            // 每次处理一批数据,而不是单条处理
             val batchSize = 100
-            val localVids = allVidsBroadcast.value.toArray // 只转换一次为数组
+            val localVids = allVidsBroadcast.value.toArray // 保持使用Array
             val random = new Random()
 
             iter.grouped(batchSize).flatMap { batch =>
-              // 为整批数据一次性生成负样本池
+              // 收集需要排除的视频ID
               val excludeVids = batch.flatMap(row =>
                 Seq(row.getAs[String]("vid_left"), row.getAs[String]("vid_right"))
               ).toSet
 
-              // 预先过滤一次负样本
+              // 过滤候选
               val candidatePool = localVids.filter(vid => !excludeVids.contains(vid))
 
               batch.flatMap { row =>
                 val vid_left = row.getAs[String]("vid_left")
-                // 从已过滤的池中快速采样
-                random.shuffle(candidatePool.take(1000))
-                  .take(20)
-                  .map(negative_vid => Row(vid_left, negative_vid, null, null, null, null, null, null, null))
+
+                // Draw distinct negative vids for vid_left by random index into candidatePool.
+                val negatives = new ArrayBuffer[String]()
+                val candidateSize = Math.min(1000, candidatePool.length)
+                // candidateSize only bounds the target count; the indices drawn below span the whole pool.
+                // Collect up to 20, retrying on repeats. NOTE(review): loops forever if the pool has too few distinct vids — confirm ids are unique.
+                while (negatives.size < 20 && negatives.size < candidateSize) {
+                  val randomIndex = random.nextInt(candidatePool.length)
+                  val candidate = candidatePool(randomIndex)
+                  if (!negatives.contains(candidate)) {
+                    negatives += candidate
+                  }
+                }
+
+                // Emit one Row per sampled negative: vid_left, negative_vid, then seven nulls.
+                negatives.map { negative_vid =>
+                  Row(
+                    vid_left,
+                    negative_vid,
+                    null, // remaining fields are left null
+                    null,
+                    null,
+                    null,
+                    null,
+                    null,
+                    null
+                  )
+                }
               }
             }
           }