often 5 月之前
父节点
当前提交
5e3154d4cf
共有 1 个文件被更改,包括 30 次插入1 次删除
  1. 30 1
      src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

+ 30 - 1
src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

@@ -149,6 +149,7 @@ object video_dssm_sampler {
 
 
       val allVidsBroadcast = spark.sparkContext.broadcast(allVids)
+
       // 注册UDF
       spark.udf.register("generateNegativeVids", generateNegativeVidsUDF)
       // 3. 定义UDF函数来生成负样本
@@ -157,7 +158,35 @@ object video_dssm_sampler {
         val negativeVids = Random.shuffle(localAllVids - vid_left - vid_right).take(2) //20
         negativeVids.toArray
       })
+
       // 4. 生成负样本
+      val negativeSamplesRDD = time({
+        positivePairs.rdd
+          .map(row => (row.getAs[String]("vid_left"), row.getAs[String]("vid_right")))
+          .flatMap { case (vid_left, vid_right) =>
+            val localAllVids = allVidsBroadcast.value
+            // 生成负样本对
+            Random.shuffle(localAllVids - vid_left - vid_right)
+              .take(2)
+              .map(negative_vid => Row(
+                vid_left,          // vid_left
+                negative_vid,      // vid_right
+                null,             // extend
+                null,             // view_24h
+                null,             // total_return_uv
+                null,             // ts_right
+                null,             // apptype
+                null,             // pagesource
+                null              // mid
+              ))
+          }
+      }, "Generating negative samples")
+      // 转换回DataFrame
+      val negativeSamplesDF = spark.createDataFrame(negativeSamplesRDD, schema)
+      stats.negativeSamplesCount = negativeSamplesRDD.count()
+
+
+      /*
       val negativeSamplesDF = time({
         val df = positivePairs
           .select("vid_left", "vid_right")
@@ -177,7 +206,7 @@ object video_dssm_sampler {
         stats.negativeSamplesCount = df.count()
         df
       }, "Generating negative samples")
-
+      */
       // 5. 合并正负样本
       val allSamples = positivePairs
         .withColumn("label", lit(1))