|
@@ -12,6 +12,13 @@ import scala.util.Random
|
|
import scala.collection.mutable.ArrayBuffer
|
|
import scala.collection.mutable.ArrayBuffer
|
|
import org.apache.log4j.{Level, Logger}
|
|
import org.apache.log4j.{Level, Logger}
|
|
|
|
|
|
|
|
+import scala.util.Random
|
|
|
|
+import scala.collection.mutable.ArrayBuffer
|
|
|
|
+import org.apache.spark.sql.{Row, DataFrame}
|
|
|
|
+import org.apache.spark.sql.types.StructType
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+
|
|
object video_dssm_sampler {
|
|
object video_dssm_sampler {
|
|
private val logger = Logger.getLogger(this.getClass)
|
|
private val logger = Logger.getLogger(this.getClass)
|
|
|
|
|
|
@@ -154,26 +161,49 @@ object video_dssm_sampler {
|
|
val negativeSamplesRDD = time({
|
|
val negativeSamplesRDD = time({
|
|
positivePairs.rdd
|
|
positivePairs.rdd
|
|
.mapPartitions { iter =>
|
|
.mapPartitions { iter =>
|
|
- // 每次处理一批数据,而不是单条处理
|
|
|
|
val batchSize = 100
|
|
val batchSize = 100
|
|
- val localVids = allVidsBroadcast.value.toArray // 只转换一次为数组
|
|
|
|
|
|
+ val localVids = allVidsBroadcast.value.toArray // 保持使用Array
|
|
val random = new Random()
|
|
val random = new Random()
|
|
|
|
|
|
iter.grouped(batchSize).flatMap { batch =>
|
|
iter.grouped(batchSize).flatMap { batch =>
|
|
- // 为整批数据一次性生成负样本池
|
|
|
|
|
|
+ // 收集需要排除的视频ID
|
|
val excludeVids = batch.flatMap(row =>
|
|
val excludeVids = batch.flatMap(row =>
|
|
Seq(row.getAs[String]("vid_left"), row.getAs[String]("vid_right"))
|
|
Seq(row.getAs[String]("vid_left"), row.getAs[String]("vid_right"))
|
|
).toSet
|
|
).toSet
|
|
|
|
|
|
- // 预先过滤一次负样本池
|
|
|
|
|
|
+ // 过滤候选池
|
|
val candidatePool = localVids.filter(vid => !excludeVids.contains(vid))
|
|
val candidatePool = localVids.filter(vid => !excludeVids.contains(vid))
|
|
|
|
|
|
batch.flatMap { row =>
|
|
batch.flatMap { row =>
|
|
val vid_left = row.getAs[String]("vid_left")
|
|
val vid_left = row.getAs[String]("vid_left")
|
|
- // 从已过滤的池中快速采样
|
|
|
|
- random.shuffle(candidatePool.take(1000))
|
|
|
|
- .take(20)
|
|
|
|
- .map(negative_vid => Row(vid_left, negative_vid, null, null, null, null, null, null, null))
|
|
|
|
|
|
+
|
|
|
|
+ // 随机采样逻辑
|
|
|
|
+ val negatives = new ArrayBuffer[String]()
|
|
|
|
+ val candidateSize = Math.min(1000, candidatePool.length)
|
|
|
|
+
|
|
|
|
+ // 随机采样20个负样本
|
|
|
|
+ while (negatives.size < 20 && negatives.size < candidateSize) {
|
|
|
|
+ val randomIndex = random.nextInt(candidatePool.length)
|
|
|
|
+ val candidate = candidatePool(randomIndex)
|
|
|
|
+ if (!negatives.contains(candidate)) {
|
|
|
|
+ negatives += candidate
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ // 转换为Row对象
|
|
|
|
+ negatives.map { negative_vid =>
|
|
|
|
+ Row(
|
|
|
|
+ vid_left,
|
|
|
|
+ negative_vid,
|
|
|
|
+ null, // 其他字段设为null
|
|
|
|
+ null,
|
|
|
|
+ null,
|
|
|
|
+ null,
|
|
|
|
+ null,
|
|
|
|
+ null,
|
|
|
|
+ null
|
|
|
|
+ )
|
|
|
|
+ }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|