often 5 months ago
parent
commit
799f0352df

+ 79 - 9
src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

@@ -244,8 +244,7 @@ object video_dssm_sampler {
       }
 
 
-      // 获取tag特征
-
+      // 获取左视频tag特征
       val tagFeatures = {
         val rdd = odpsOps.readTable(
           project = "loghubods",
@@ -262,7 +261,7 @@ object video_dssm_sampler {
       }
 
 
-      // 获取统计特征
+      // 获取左视频统计特征
       val statFeatures = {
         val rdd = odpsOps.readTable(
           project = "loghubods",
@@ -304,11 +303,73 @@ object video_dssm_sampler {
         )
         .persist(StorageLevel.MEMORY_AND_DISK)
 
-      // 合并所有右侧特征并生成最终结果
+
+
+
+
+
+      // 获取右视频tag特征
+      val tagRightFeatures = {
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_tag_feature",
+          partition = s"dt='$dt'",
+          transfer = funcTagFeatures,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+        val schema = StructType(Array(
+          StructField("vid", StringType, true),
+          StructField("vid_right_tag_feature", StringType, true)
+        ))
+        spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
+      }
+
+
+      // 获取右视频统计特征
+      val statRightFeatures = {
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_stat_feature",
+          partition = s"dt='$dt'",
+          transfer = funcStatFeatures,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+        val schema = StructType(Array(
+          StructField("vid", StringType, true),
+          StructField("vid_right_stat_feature", StringType, true)
+        ))
+        spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
+      }
+
+
+      val categoryRightWithStats = categoryData
+        .join(broadcast(l1CatStatFeatures), categoryData("category1") === l1CatStatFeatures("category1"), "left")
+        .join(broadcast(l2CatStatFeatures), categoryData("category2") === l2CatStatFeatures("category2"), "left")
+        .select(
+          col("vid"),
+          col("cate_l1_feature").as("vid_right_cate_l1_feature"),
+          col("cate_l2_feature").as("vid_right_cate_l2_feature")
+        )
+
       val result = vidLeftFeatures
-      //  .join(broadcast(tagFeatures), col("vid_right") === tagFeatures("vid"), "left")
-      //  .join(broadcast(statFeatures), col("vid_right") === statFeatures("vid"), "left")
-      //  .join(broadcast(categoryData), col("vid_right") === categoryData("vid"), "left")
+        .join(broadcast(tagRightFeatures), col("vid_right") === tagRightFeatures("vid"), "left")
+        .drop(tagRightFeatures("vid"))
+        .join(broadcast(statRightFeatures), col("vid_right") === statRightFeatures("vid"), "left")
+        .drop(statRightFeatures("vid"))
+        .join(broadcast(categoryRightWithStats), col("vid_right") === categoryRightWithStats("vid"), "left")
+        .drop(categoryRightWithStats("vid"))
+        .select(
+          vidLeftFeatures("*"),
+          col("vid_right_tag_feature"),
+          col("vid_right_stat_feature"),
+          col("vid_right_cate_l1_feature"),
+          col("vid_right_cate_l2_feature")
+        )
+        .persist(StorageLevel.MEMORY_AND_DISK)
+
+
+
+
 
       // 保存结果到HDFS
       val resultWithDt = result.withColumn("dt", lit(s"$dt"))
@@ -321,10 +382,19 @@ object video_dssm_sampler {
 
       // 8. 清理缓存
       positivePairs.unpersist()
+      negativeSamplesDF.unpersist()
       allSamples.unpersist()
+      l1CatStatFeatures.unpersist()
+      l2CatStatFeatures.unpersist()
+      tagFeatures.unpersist()
+      statFeatures.unpersist()
+      categoryData.unpersist()
+      categoryWithStats.unpersist()
       vidLeftFeatures.unpersist()
-
-
+      tagRightFeatures.unpersist()
+      statRightFeatures.unpersist()
+      categoryRightWithStats.unpersist()
+      result.unpersist()
       // 输出统计信息
       stats.logStats()