Browse Source

limit record number

often 5 months ago
parent
commit
0850f8aa36

+ 28 - 21
src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

@@ -116,6 +116,8 @@ object video_dssm_sampler {
         StructField("pagesource", StringType, true),
         StructField("mid", StringType, true)
       ))
+
+
       // logger.info(s"start read positivePairs for date $dt")
       val partition = s"dt=20241124"
       val rdd = odpsOps.readTable(
@@ -124,7 +126,8 @@ object video_dssm_sampler {
         partition = s"dt=20241124,hh=08",
         transfer = funcPositive,
         numPartition = CONFIG("shuffle.partitions").toInt
-      ).persist(StorageLevel.MEMORY_AND_DISK)
+      ).sample(false, 0.001) // 随机抽样千分之一的数据
+       .persist(StorageLevel.MEMORY_AND_DISK)
       println("开始执行partiton:" + partition)
 
       val positivePairs = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK)
@@ -215,6 +218,30 @@ object video_dssm_sampler {
         ))
         spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
       }
+
+      val categoryData = {
+        // 从 t_vid_tag_feature 表读取数据
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_tag_feature",
+          partition = s"dt='$dt'",
+          transfer = funcCategoryFeatures,  // 使用之前定义的 funcCategoryFeatures 函数
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+
+        // 定义 schema
+        val schema = StructType(Array(
+          StructField("vid", StringType, true),
+          StructField("category1", StringType, true),
+          StructField("category2", StringType, true),
+          StructField("feature", StringType, true)
+        ))
+
+        // 将 RDD 转换为 DataFrame
+        spark.createDataFrame(rdd.map(t => Row(t._1, t._2, t._3, t._4)), schema)
+      }
+
+
       // 获取tag特征
 
       val tagFeatures = {
@@ -248,27 +275,7 @@ object video_dssm_sampler {
         ))
         spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
       }
-      val categoryData = {
-        // 从 t_vid_tag_feature 表读取数据
-        val rdd = odpsOps.readTable(
-          project = "loghubods",
-          table = "t_vid_tag_feature",
-          partition = s"dt='$dt'",
-          transfer = funcCategoryFeatures,  // 使用之前定义的 funcCategoryFeatures 函数
-          numPartition = CONFIG("shuffle.partitions").toInt
-        )
 
-        // 定义 schema
-        val schema = StructType(Array(
-          StructField("vid", StringType, true),
-          StructField("category1", StringType, true),
-          StructField("category2", StringType, true),
-          StructField("feature", StringType, true)
-        ))
-
-        // 将 RDD 转换为 DataFrame
-        spark.createDataFrame(rdd.map(t => Row(t._1, t._2, t._3, t._4)), schema)
-      }
 
       val categoryWithStats = categoryData
         .join(broadcast(l1CatStatFeatures), categoryData("category1") === l1CatStatFeatures("category1"), "left")