Browse Source

dec memory cost

often 5 months ago
parent
commit
7e9796d29f

+ 55 - 3
src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

@@ -7,7 +7,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import com.jayway.jsonpath.JsonPath
-
+import org.apache.spark.sql.expressions.Window
 import scala.util.Random
 import scala.collection.mutable.ArrayBuffer
 import org.apache.log4j.{Level, Logger}
@@ -136,6 +136,58 @@ object video_dssm_sampler {
       stats.positiveSamplesCount = positivePairs.count()
       logger.info(s"start read vid list for date $dt")
 
+      val numberedPos = positivePairs // give every positive pair a sequential pos_id (row_number starts at 1)
+        .withColumn("pos_id", row_number().over(Window.orderBy("vid_left", "vid_right"))) // NOTE(review): a window without partitionBy is evaluated in a single partition — confirm this is acceptable at production volume
+      // 2. Read the candidate videos and number them in groups of 100
+      val vidsRDD = odpsOps.readTable(
+        project = "loghubods",
+        table = "t_vid_tag_feature",
+        partition = s"dt='$dt'",
+        transfer = funcVids, // defined outside this hunk; presumably extracts the vid string from each record — TODO confirm
+        numPartition = CONFIG("shuffle.partitions").toInt
+      )
+
+      // Convert the RDD to a single-column ("vid") DataFrame
+      val numberedVid = spark.createDataFrame(vidsRDD.map(vid => Row(vid)),
+          StructType(Seq(StructField("vid", StringType, true))))
+        .withColumn("group_id", floor((row_number().over(Window.orderBy(rand())) - 1) / 100)) // random order, then 100 vids per 0-based group_id; NOTE(review): also an unpartitioned window
+        .persist(StorageLevel.MEMORY_AND_DISK_SER)
+
+      val negativeSamples = numberedPos // candidate negatives: each positive pair crossed with one group of vids
+        .join(numberedVid, // equi-join of the pair's 100-row bucket index with the vid group index
+          floor((col("pos_id") - 1) / 100) === col("group_id")) // NOTE(review): buckets beyond the largest group_id match nothing — verify pair count vs. vid count
+        .where(col("vid") =!= col("vid_left") && col("vid") =!= col("vid_right")) // never reuse either side of the positive pair
+        .select(
+          col("dt"),
+          col("hh"),
+          lit(null).as("extend"), // this and the following lit(null) columns carry no value for sampled negatives
+          lit(null).as("view_24h"),
+          lit(null).as("total_return_uv"),
+          lit(null).as("ts_right"),
+          lit(null).as("apptype"),
+          lit(null).as("pagesource"),
+          lit(null).as("mid"),
+          col("vid_left"),
+          col("vid").as("vid_right") // the sampled vid takes the vid_right slot
+        )
+        .withColumn("rn", row_number().over(Window.partitionBy("vid_left").orderBy(rand()))) // random rank among the candidates of each vid_left
+      // 4. Keep 2 negative samples for each vid_left
+      val filteredNegative = negativeSamples
+        .where(col("rn") <= 2)
+        .drop("rn")
+
+      // 5. Merge positive and negative samples
+      val allSamples = positivePairs
+        .withColumn("label", lit(1))
+        .withColumn("logid", concat(lit("pos_"), monotonically_increasing_id())) // ids are unique but not consecutive
+        .union( // union matches columns by position; assumes positivePairs has the column order selected above — TODO confirm
+          filteredNegative
+            .withColumn("label", lit(0))
+            .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id()))
+        )
+        .persist(StorageLevel.MEMORY_AND_DISK_SER)
+
+      /*
       // 2. 获取所有可用的vid列表
       val allVids = time({
         odpsOps.readTable(
@@ -216,7 +268,7 @@ object video_dssm_sampler {
             .withColumn("label", lit(0))
             .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id()))
         ).persist(StorageLevel.MEMORY_AND_DISK_SER)
-
+*/
       // 6. 获取左侧特征
       // 读取L1类别统计特征
       val l1CatStatFeatures = {
@@ -409,7 +461,7 @@ object video_dssm_sampler {
 
       // 8. 清理缓存
       positivePairs.unpersist()
-      negativeSamplesDF.unpersist()
+      numberedVid.unpersist() // release the persisted candidate-vid table; negativeSamples itself is never cached
       allSamples.unpersist()
       l1CatStatFeatures.unpersist()
       l2CatStatFeatures.unpersist()