@@ -12,6 +12,13 @@ import scala.util.Random
 import scala.collection.mutable.ArrayBuffer
 import org.apache.log4j.{Level, Logger}
+import scala.util.Random
+import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.types.StructType
+
+
+

 object video_dssm_sampler {
   private val logger = Logger.getLogger(this.getClass)

@@ -154,26 +161,49 @@ object video_dssm_sampler {
     val negativeSamplesRDD = time({
       positivePairs.rdd
         .mapPartitions { iter =>
-
           val batchSize = 100
-          val localVids = allVidsBroadcast.value.toArray
+          val localVids = allVidsBroadcast.value.toArray
           val random = new Random()

           iter.grouped(batchSize).flatMap { batch =>
-
+
             val excludeVids = batch.flatMap(row =>
               Seq(row.getAs[String]("vid_left"), row.getAs[String]("vid_right"))
             ).toSet

-
+
             val candidatePool = localVids.filter(vid => !excludeVids.contains(vid))

             batch.flatMap { row =>
               val vid_left = row.getAs[String]("vid_left")
-
-              random.shuffle(candidatePool.take(1000))
-                .take(20)
-                .map(negative_vid => Row(vid_left, negative_vid, null, null, null, null, null, null, null))
+
+
+              val negatives = new ArrayBuffer[String]()
+              val candidateSize = Math.min(1000, candidatePool.length)
+
+
+              while (negatives.size < 20 && negatives.size < candidateSize) {
+                val randomIndex = random.nextInt(candidatePool.length)
+                val candidate = candidatePool(randomIndex)
+                if (!negatives.contains(candidate)) {
+                  negatives += candidate
+                }
+              }
+
+
+              negatives.map { negative_vid =>
+                Row(
+                  vid_left,
+                  negative_vid,
+                  null,
+                  null,
+                  null,
+                  null,
+                  null,
+                  null,
+                  null
+                )
+              }
             }
           }
         }
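
The change in this hunk replaces random.shuffle(candidatePool.take(1000)).take(20), which only ever drew negatives from the first 1000 candidates, with rejection sampling over the whole pool. Below is a minimal standalone sketch of that sampling step for reference; the object and method names (NegativeSamplingSketch, sampleNegatives), the fixed seed, and the toy vid values are illustrative assumptions, not part of the patch.

import scala.util.Random
import scala.collection.mutable.ArrayBuffer

object NegativeSamplingSketch {
  // Draw up to `count` distinct vids from `candidatePool` by rejection sampling,
  // mirroring the while-loop added in the patch above.
  def sampleNegatives(candidatePool: Array[String], count: Int, random: Random): Seq[String] = {
    val negatives = new ArrayBuffer[String]()
    val limit = Math.min(count, candidatePool.length)
    while (negatives.size < limit) {
      val candidate = candidatePool(random.nextInt(candidatePool.length))
      if (!negatives.contains(candidate)) {
        negatives += candidate
      }
    }
    negatives.toSeq
  }

  def main(args: Array[String]): Unit = {
    val pool = (1 to 100).map(i => s"vid_$i").toArray   // toy candidate pool
    val picks = sampleNegatives(pool, 20, new Random(42))
    println(picks.mkString(", "))                       // 20 distinct vids in random order
  }
}

Note that the Math.min(1000, candidatePool.length) cap in the patch is effectively redundant: because the loop also stops at 20 negatives, the bound reduces to min(20, candidatePool.length), which is what the sketch uses.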