Parcourir la source

get neg sample

often il y a 5 mois
Parent
commit
176c27550b

+ 109 - 80
src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

@@ -122,94 +122,123 @@ object video_dssm_sampler {
       }, "Generating negative samples")
 
       // 5. 合并正负样本
-      val allSamples = time({
-        positivePairs
-          .withColumn("label", lit(1))
-          .withColumn("logid", concat(lit("pos_"), monotonically_increasing_id()))
-          .union(
-            negativeSamplesDF
-              .withColumn("label", lit(0))
-              .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id()))
-          )
-          .persist(StorageLevel.MEMORY_AND_DISK)
-      }, "Merging positive and negative samples")
+      // Tag each pair set with its label (1 for positivePairs, 0 for negativeSamplesDF)
+      // and a prefixed log id, stack the two sets, and cache the union for the joins below.
+      val labeledPositives = positivePairs
+        .withColumn("label", lit(1))
+        .withColumn("logid", concat(lit("pos_"), monotonically_increasing_id()))
+      val labeledNegatives = negativeSamplesDF
+        .withColumn("label", lit(0))
+        .withColumn("logid", concat(lit("neg_"), monotonically_increasing_id()))
+      val allSamples = labeledPositives
+        .union(labeledNegatives)
+        .persist(StorageLevel.MEMORY_AND_DISK)
 
-      // 6. 获取特征数据
-      val features = time({
-        spark.sql(s"""
+      // 6. Load the left-side features
+      // Tag features: reads vid and feature from loghubods.t_vid_tag_feature for partition
+      // dt, exposing feature under the name vid_left_tag_feature.
+      val tagFeatures = spark.sql(s"""
+        SELECT vid, feature as vid_left_tag_feature
+        FROM loghubods.t_vid_tag_feature
+        WHERE dt = '$dt'
+      """)
+
+      // Statistical features: reads vid and feature from loghubods.t_vid_stat_feature for
+      // partition dt, exposing feature under the name vid_left_stat_feature.
+      val statFeatures = spark.sql(s"""
+        SELECT vid, feature as vid_left_stat_feature
+        FROM loghubods.t_vid_stat_feature
+        WHERE dt = '$dt'
+      """)
+
+      // 获取类别特征
+      val categoryFeatures = spark.sql(s"""
+        SELECT
+          a.vid,
+          b.feature as vid_left_cate_l1_feature,
+          c.feature as vid_left_cate_l2_feature
+        FROM (
           SELECT
-            t.vid,
-            t.feature as tag_feature,
-            s.feature as stat_feature,
-            c.cate_l1_feature,
-            c.cate_l2_feature
-          FROM loghubods.t_vid_tag_feature t
-          LEFT JOIN loghubods.t_vid_stat_feature s ON t.vid = s.vid AND s.dt = '$dt'
-          LEFT JOIN (
-            SELECT
-              a.vid,
-              b.feature as cate_l1_feature,
-              c.feature as cate_l2_feature
-            FROM (
-              SELECT
-                vid,
-                get_json_object(feature,'$.category1') as category1,
-                get_json_object(feature,'$.category2_1') as category2
-              FROM loghubods.t_vid_tag_feature
-              WHERE dt = '$dt'
-            ) a
-            LEFT JOIN (
-              SELECT category1, feature
-              FROM t_vid_l1_cat_stat_feature
-              WHERE dt = '$dt'
-            ) b ON a.category1 = b.category1
-            LEFT JOIN (
-              SELECT category2, feature
-              FROM t_vid_l2_cat_stat_feature
-              WHERE dt = '$dt'
-            ) c ON a.category2 = c.category2
-          ) c ON t.vid = c.vid
-          WHERE t.dt = '$dt'
-        """).persist(StorageLevel.MEMORY_AND_DISK)
-      }, "Loading features")
-
-      // 7. 添加特征并保存结果
-      time({
-        val result = allSamples
-          .join(broadcast(features).as("left_features"),
-            col("vid_left") === col("left_features.vid"),
-            "left")
-          .select(
-            allSamples("*"),
-            col("left_features.tag_feature").as("vid_left_tag_feature"),
-            col("left_features.stat_feature").as("vid_left_stat_feature"),
-            col("left_features.cate_l1_feature").as("vid_left_cate_l1_feature"),
-            col("left_features.cate_l2_feature").as("vid_left_cate_l2_feature")
-          )
-          .join(broadcast(features).as("right_features"),
-            col("vid_right") === col("right_features.vid"),
-            "left")
-          .select(
-            col("*"),
-            col("right_features.tag_feature").as("vid_right_tag_feature"),
-            col("right_features.stat_feature").as("vid_right_stat_feature"),
-            col("right_features.cate_l1_feature").as("vid_right_cate_l1_feature"),
-            col("right_features.cate_l2_feature").as("vid_right_cate_l2_feature")
-          )
+            vid,
+            get_json_object(feature,"$$.category1") as category1,
+            get_json_object(feature,"$$.category2_1") as category2
+          FROM loghubods.t_vid_tag_feature
+          WHERE dt = '$dt'
+        ) a
+        LEFT JOIN (
+          SELECT category1, feature
+          FROM t_vid_l1_cat_stat_feature
+          WHERE dt = '$dt'
+        ) b ON a.category1 = b.category1
+        LEFT JOIN (
+          SELECT category2, feature
+          FROM t_vid_l2_cat_stat_feature
+          WHERE dt = '$dt'
+        ) c ON a.category2 = c.category2
+      """)
+
+      // Attach all left-side features to the samples, matching each feature table's vid
+      // against vid_left. Every feature table carries a broadcast hint and is left-joined,
+      // so samples with no matching vid are kept with null feature columns. The joined
+      // result is cached because the right-side joins build on it.
+      // NOTE(review): each join also carries the feature table's own `vid` column into the output.
+      val vidLeftFeatures = allSamples
+        .join(broadcast(tagFeatures), col("vid_left") === tagFeatures("vid"), "left")
+        .join(broadcast(statFeatures), col("vid_left") === statFeatures("vid"), "left")
+        .join(broadcast(categoryFeatures), col("vid_left") === categoryFeatures("vid"), "left")
+        .persist(StorageLevel.MEMORY_AND_DISK)
+
+      // 7. Load the right-side features and build the final result
+      // Tag features: same table and dt filter as the left-side tag query, but the feature
+      // column is exposed as vid_right_tag_feature.
+      val rightTagFeatures = spark.sql(s"""
+        SELECT vid, feature as vid_right_tag_feature
+        FROM loghubods.t_vid_tag_feature
+        WHERE dt = '$dt'
+      """)
+
+      // Statistical features: same table and dt filter as the left-side stat query, but the
+      // feature column is exposed as vid_right_stat_feature.
+      val rightStatFeatures = spark.sql(s"""
+        SELECT vid, feature as vid_right_stat_feature
+        FROM loghubods.t_vid_stat_feature
+        WHERE dt = '$dt'
+      """)
+
+      // Category features for the right-hand video.
+      // Subquery a extracts category1 and category2_1 from the tag feature JSON (the "$$" in
+      // this interpolated string is an escaped "$"), then left-joins
+      // t_vid_l1_cat_stat_feature on category1 and t_vid_l2_cat_stat_feature on category2.
+      // NOTE(review): those two tables are referenced without the loghubods. prefix used by
+      // the other tables — confirm they resolve in the session's current database.
+      val rightCategoryFeatures = spark.sql(s"""
+        SELECT
+          a.vid,
+          b.feature as vid_right_cate_l1_feature,
+          c.feature as vid_right_cate_l2_feature
+        FROM (
+          SELECT
+            vid,
+            get_json_object(feature,"$$.category1") as category1,
+            get_json_object(feature,"$$.category2_1") as category2
+          FROM loghubods.t_vid_tag_feature
+          WHERE dt = '$dt'
+        ) a
+        LEFT JOIN (
+          SELECT category1, feature
+          FROM t_vid_l1_cat_stat_feature
+          WHERE dt = '$dt'
+        ) b ON a.category1 = b.category1
+        LEFT JOIN (
+          SELECT category2, feature
+          FROM t_vid_l2_cat_stat_feature
+          WHERE dt = '$dt'
+        ) c ON a.category2 = c.category2
+      """)
+
+      // Join the right-side features and produce the final result.
+      // Every feature join (the left-side ones feeding vidLeftFeatures and the right-side
+      // ones here) uses an explicit equality condition, so each feature table's own `vid`
+      // key column survives into the output. That leaves six columns named `vid`, and Spark
+      // rejects writing a schema with duplicate column names. The key columns are therefore
+      // dropped by column reference, which removes exactly those six and nothing else.
+      val featureJoinKeys = Seq(
+        tagFeatures("vid"), statFeatures("vid"), categoryFeatures("vid"),
+        rightTagFeatures("vid"), rightStatFeatures("vid"), rightCategoryFeatures("vid")
+      )
+      val withRightFeatures = vidLeftFeatures
+        .join(broadcast(rightTagFeatures), col("vid_right") === rightTagFeatures("vid"), "left")
+        .join(broadcast(rightStatFeatures), col("vid_right") === rightStatFeatures("vid"), "left")
+        .join(broadcast(rightCategoryFeatures), col("vid_right") === rightCategoryFeatures("vid"), "left")
+      val result = featureJoinKeys.foldLeft(withRightFeatures)((df, key) => df.drop(key))
 
-        // 保存到HDFS
-        result.write
-          .mode("overwrite")
-          .partitionBy("dt")
-          .parquet(s"$outputPath/dt=$dt")
+      // Save the result to HDFS as parquet, overwriting any existing output at this path.
+      // NOTE(review): the target path already ends in dt=$dt and the writer also partitions
+      // by the dt column, so files presumably land under .../dt=<dt>/dt=<value>/ — confirm
+      // that this nesting is intended.
+      result.write
+        .mode("overwrite")
+        .partitionBy("dt")
+        .parquet(s"$outputPath/dt=$dt")
 
-        logger.info(s"Results saved to: $outputPath/dt=$dt")
-      }, "Adding features and saving results")
+      logger.info(s"Results saved to: $outputPath/dt=$dt")
 
       // 8. 清理缓存
       positivePairs.unpersist()
       allSamples.unpersist()
-      features.unpersist()
+      // Release the cached left-side join result. The feature DataFrames built from
+      // spark.sql (including rightTagFeatures) are never persisted, so they need no unpersist.
+      vidLeftFeatures.unpersist()
 
       // 输出统计信息
       stats.logStats()