|
@@ -149,6 +149,7 @@ object video_dssm_sampler {
|
|
|
|
|
|
|
|
|
val allVidsBroadcast = spark.sparkContext.broadcast(allVids)
|
|
|
+
|
|
|
// 注册UDF
|
|
|
spark.udf.register("generateNegativeVids", generateNegativeVidsUDF)
|
|
|
// 3. 定义UDF函数来生成负样本
|
|
@@ -157,7 +158,35 @@ object video_dssm_sampler {
|
|
|
val negativeVids = Random.shuffle(localAllVids - vid_left - vid_right).take(2) //20
|
|
|
negativeVids.toArray
|
|
|
})
|
|
|
+
|
|
|
// 4. 生成负样本
|
|
|
+ val negativeSamplesRDD = time({
|
|
|
+ positivePairs.rdd
|
|
|
+ .map(row => (row.getAs[String]("vid_left"), row.getAs[String]("vid_right")))
|
|
|
+ .flatMap { case (vid_left, vid_right) =>
|
|
|
+ val localAllVids = allVidsBroadcast.value
|
|
|
+ // 生成负样本对
|
|
|
+ Random.shuffle(localAllVids - vid_left - vid_right)
|
|
|
+ .take(2)
|
|
|
+ .map(negative_vid => Row(
|
|
|
+ vid_left, // vid_left
|
|
|
+ negative_vid, // vid_right
|
|
|
+ null, // extend
|
|
|
+ null, // view_24h
|
|
|
+ null, // total_return_uv
|
|
|
+ null, // ts_right
|
|
|
+ null, // apptype
|
|
|
+ null, // pagesource
|
|
|
+ null // mid
|
|
|
+ ))
|
|
|
+ }
|
|
|
+ }, "Generating negative samples")
|
|
|
+ // 转换回DataFrame
|
|
|
+ val negativeSamplesDF = spark.createDataFrame(negativeSamplesRDD, schema)
|
|
|
+ stats.negativeSamplesCount = negativeSamplesRDD.count()
|
|
|
+
|
|
|
+
|
|
|
+ /*
|
|
|
val negativeSamplesDF = time({
|
|
|
val df = positivePairs
|
|
|
.select("vid_left", "vid_right")
|
|
@@ -177,7 +206,7 @@ object video_dssm_sampler {
|
|
|
stats.negativeSamplesCount = df.count()
|
|
|
df
|
|
|
}, "Generating negative samples")
|
|
|
-
|
|
|
+ */
|
|
|
// 5. 合并正负样本
|
|
|
val allSamples = positivePairs
|
|
|
.withColumn("label", lit(1))
|