commit 4ba00f0991 (often, 5 months ago)
1 file changed, 80 insertions(+), 27 deletions(-)
      src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala


@@ -54,9 +54,11 @@ object video_dssm_sampler {
   }
 
   def funcPositive(record: Record, schema: TableSchema): Row = {
-      Row(
-      record.getString("dt"),
-      record.getString("hh"),
+    //println(s"schema.getColumns.toString = ${schema.getColumns.toString}")
+    //println(s"Record: ${record.toString}")
+    Row(
+      //record.getString("dt"),
+      //record.getString("hh"),
       record.getString("vid_left"),
       record.getString("vid_right"),
       record.getString("extend"),
@@ -83,7 +85,10 @@ object video_dssm_sampler {
     val category2 = JsonPath.read[String](feature, "$.category2_1")
     (record.getString("vid"), category1, category2, feature)
   }
-
+  // 1. First, define the reader function for the L1 and L2 category statistical features
+  def funcCatStatFeatures(record: Record, schema: TableSchema): (String, String) = {
+    (record.getString("category1"), record.getString("feature"))  // or "category2", depending on the table
+  }
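Note that funcCatStatFeatures reads "category1" for both stat tables, while the L2 table presumably keys on "category2" (the in-code comment above acknowledges this). A minimal sketch of a per-table transfer, assuming the key column name is the only difference between the two tables:

    // Hypothetical helper: parameterize the key column so each table reads its own key.
    def funcCatStatFeaturesFor(keyCol: String)(record: Record, schema: TableSchema): (String, String) =
      (record.getString(keyCol), record.getString("feature"))

    // Assumed usage: funcCatStatFeaturesFor("category1") _ for the L1 table,
    //                funcCatStatFeaturesFor("category2") _ for the L2 table.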
 
 
 
@@ -99,8 +104,8 @@ object video_dssm_sampler {
      // 1. Fetch the positive-sample pairs
 
       val schema = StructType(Array(
-        StructField("dt", StringType, true),
-        StructField("hh", StringType, true),
+        //StructField("dt", StringType, true),
+        //StructField("hh", StringType, true),
         StructField("vid_left", StringType, true),
         StructField("vid_right", StringType, true),
         StructField("extend", StringType, true),
@@ -111,15 +116,17 @@ object video_dssm_sampler {
         StructField("pagesource", StringType, true),
         StructField("mid", StringType, true)
       ))
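Since dt/hh were dropped from both funcPositive and this StructType, the two column lists must be kept in lockstep by hand. A hedged sketch of deriving the Row from the schema so they cannot drift (it assumes every column is read as a string, as in the current funcPositive):

    // Hypothetical: list the columns once, in the schema, and build the Row from it.
    def funcPositiveFromSchema(positiveSchema: StructType)(record: Record, schema: TableSchema): Row =
      Row.fromSeq(positiveSchema.fields.map(f => record.getString(f.name)))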
-      logger.info(s"start read positivePairs for date $dt")
+      // logger.info(s"start read positivePairs for date $dt")
+      val partition = "dt=20241124,hh=08"  // hardcoded debug partition; was s"dt='$dt'"
       val rdd = odpsOps.readTable(
         project = "loghubods",
         table = "alg_dssm_sample",
-        partition = s"dt='$dt'",
+        partition = partition,
         transfer = funcPositive,
         numPartition = CONFIG("shuffle.partitions").toInt
       ).persist(StorageLevel.MEMORY_AND_DISK)
-      logger.info(s"11111111111111 $dt")
+      println("Start processing partition: " + partition)
+
       val positivePairs = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK)
       stats.positiveSamplesCount = positivePairs.count()
       logger.info(s"start read vid list for date $dt")
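The partition spec is hardcoded to a single dt/hh here, evidently for debugging, and the println now reports which partition is being read. A sketch of deriving the spec from the job's dt argument instead (hh=08 is this commit's debug value, not a property of the table):

    // Hypothetical: build the partition spec from the job argument rather than a literal.
    val hh = "08"                        // assumed hour partition for the debug run
    val partition = s"dt=$dt,hh=$hh"     // yields e.g. "dt=20241124,hh=08"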
@@ -136,22 +143,21 @@ object video_dssm_sampler {
       }, "Loading all vids")
 
 
-
+      val allVidsBroadcast = spark.sparkContext.broadcast(allVids)
      // Register the UDF (note: the SQL registration is unused here; the UDF is invoked directly below)
      spark.udf.register("generateNegativeVids", generateNegativeVidsUDF)
      // 3. Define the UDF that generates negative samples
       def generateNegativeVidsUDF = udf((vid_left: String, vid_right: String) => {
-        val negativeVids = Random.shuffle(allVids - vid_left - vid_right).take(20)
-        negativeVids
+        val localAllVids = allVidsBroadcast.value
+        val negativeVids = Random.shuffle(localAllVids - vid_left - vid_right).take(2) // originally 20; reduced for the debug run
+        negativeVids.toArray
      })
      // 4. Generate negative samples
       val negativeSamplesDF = time({
         val df = positivePairs
-          .select("dt", "hh", "vid_left", "vid_right")
+          .select("vid_left", "vid_right")
           .withColumn("negative_vids", generateNegativeVidsUDF(col("vid_left"), col("vid_right")))
           .select(
-            col("dt"),
-            col("hh"),
             col("vid_left"),
             explode(col("negative_vids")).as("vid_right"),
             lit(null).as("extend"),
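The change above broadcasts allVids and dereferences it inside the UDF, so the vid set ships to each executor once instead of being serialized into every task closure. A self-contained sketch of the pattern (it assumes allVids is a Set[String]; note that Random.shuffle on a Set is effectively a no-op, so converting to a Seq first is what actually randomizes the draw):

    import scala.util.Random
    import org.apache.spark.sql.functions.udf

    val allVidsBroadcast = spark.sparkContext.broadcast(allVids)   // shipped once per executor
    val generateNegativeVidsUDF = udf((vidLeft: String, vidRight: String) => {
      val vids = allVidsBroadcast.value                            // read-only executor-local copy
      Random.shuffle((vids - vidLeft - vidRight).toSeq).take(2).toArray
    })

Shuffling the whole set is O(|allVids|) per row; for large vid pools, sampling random indices from a broadcast Array would be considerably cheaper.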
@@ -178,6 +184,37 @@ object video_dssm_sampler {
         ).persist(StorageLevel.MEMORY_AND_DISK)
 
      // 6. Fetch the left-side features
+      // Read the L1 category statistical features
+      val l1CatStatFeatures = {
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_l1_cat_stat_feature",
+          partition = s"dt='$dt'",
+          transfer = funcCatStatFeatures,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+        val schema = StructType(Array(
+          StructField("category1", StringType, true),
+          StructField("cate_l1_feature", StringType, true)
+        ))
+        spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
+      }
+
+      // Read the L2 category statistical features
+      val l2CatStatFeatures = {
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_l2_cat_stat_feature",
+          partition = s"dt='$dt'",
+          transfer = funcCatStatFeatures,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+        val schema = StructType(Array(
+          StructField("category2", StringType, true),
+          StructField("cate_l2_feature", StringType, true)
+        ))
+        spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
+      }
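The L1 and L2 blocks above differ only in table and column names; a hedged refactor sketch that collapses the repeated read-then-toDF pattern (the helper name and signature are invented for illustration):

    // Hypothetical helper: read a two-column ODPS feature table into a DataFrame.
    def readKeyedFeatureTable(table: String, keyCol: String, outCol: String): DataFrame = {
      val rdd = odpsOps.readTable(
        project = "loghubods",
        table = table,
        partition = s"dt='$dt'",
        transfer = (r: Record, s: TableSchema) => (r.getString(keyCol), r.getString("feature")),
        numPartition = CONFIG("shuffle.partitions").toInt
      )
      val schema = StructType(Array(
        StructField(keyCol, StringType, true),
        StructField(outCol, StringType, true)
      ))
      spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
    }

    // val l1CatStatFeatures = readKeyedFeatureTable("t_vid_l1_cat_stat_feature", "category1", "cate_l1_feature")
    // val l2CatStatFeatures = readKeyedFeatureTable("t_vid_l2_cat_stat_feature", "category2", "cate_l2_feature")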
      // Fetch the tag features
 
       val tagFeatures = {
@@ -211,15 +248,13 @@ object video_dssm_sampler {
         ))
         spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
       }
-
-
-      // Read and process the category features
      val categoryData = {
+        // Read from the t_vid_tag_feature table
         val rdd = odpsOps.readTable(
           project = "loghubods",
           table = "t_vid_tag_feature",
           partition = s"dt='$dt'",
-          transfer = funcCategoryFeatures,
+          transfer = funcCategoryFeatures,  // uses the funcCategoryFeatures function defined earlier
           numPartition = CONFIG("shuffle.partitions").toInt
         )
 
@@ -234,19 +269,37 @@ object video_dssm_sampler {
        // Convert the RDD to a DataFrame
         spark.createDataFrame(rdd.map(t => Row(t._1, t._2, t._3, t._4)), schema)
       }
-      // Merge all left-side features
+
+      val categoryWithStats = categoryData
+        .join(broadcast(l1CatStatFeatures), categoryData("category1") === l1CatStatFeatures("category1"), "left")
+        .join(broadcast(l2CatStatFeatures), categoryData("category2") === l2CatStatFeatures("category2"), "left")
+        .select(
+          col("vid"),
+          col("cate_l1_feature").as("vid_left_cate_l1_feature"),
+          col("cate_l2_feature").as("vid_left_cate_l2_feature")
+        )
+
       val vidLeftFeatures = allSamples
         .join(broadcast(tagFeatures), col("vid_left") === tagFeatures("vid"), "left")
+        .drop(tagFeatures("vid"))
         .join(broadcast(statFeatures), col("vid_left") === statFeatures("vid"), "left")
-        .join(broadcast(categoryData), col("vid_left") === categoryData("vid"), "left")
+        .drop(statFeatures("vid"))
+        .join(broadcast(categoryWithStats), col("vid_left") === categoryWithStats("vid"), "left")
+        .drop(categoryWithStats("vid"))
+        .select(
+          allSamples("*"),
+          col("vid_left_tag_feature"),
+          col("vid_left_stat_feature"),
+          col("vid_left_cate_l1_feature"),
+          col("vid_left_cate_l2_feature")
+        )
         .persist(StorageLevel.MEMORY_AND_DISK)
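The drop(...) calls added above remove the duplicate vid columns that previously made references like col("vid_left") ambiguous after three joins. An equivalent, arguably clearer approach is to rename the join key on each feature frame before joining (a sketch; it assumes tagFeatures carries columns vid and vid_left_tag_feature, as the select above implies):

    // Alias the join key so no two frames share a column name.
    val tagL = tagFeatures.withColumnRenamed("vid", "tag_vid")
    val joined = allSamples
      .join(broadcast(tagL), col("vid_left") === col("tag_vid"), "left")
      .drop("tag_vid")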
 
-
      // Merge all right-side features and produce the final result (the right-side joins below are commented out for now)
       val result = vidLeftFeatures
-        .join(broadcast(tagFeatures), col("vid_right") === tagFeatures("vid"), "left")
-        .join(broadcast(statFeatures), col("vid_right") === statFeatures("vid"), "left")
-        .join(broadcast(categoryData), col("vid_right") === categoryData("vid"), "left")
+      //  .join(broadcast(tagFeatures), col("vid_right") === tagFeatures("vid"), "left")
+      //  .join(broadcast(statFeatures), col("vid_right") === statFeatures("vid"), "left")
+      //  .join(broadcast(categoryData), col("vid_right") === categoryData("vid"), "left")
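With the right-side joins commented out, result is currently just vidLeftFeatures, so the written output carries no vid_right features. If they are restored, the right side needs the same disambiguation plus renamed feature columns; a hedged sketch (the *_right_* names are assumptions, not columns from the source tables):

    // Hypothetical: rename right-side columns so they cannot collide with the left-side ones.
    val tagR = tagFeatures
      .withColumnRenamed("vid", "r_vid")
      .withColumnRenamed("vid_left_tag_feature", "vid_right_tag_feature")
    val result = vidLeftFeatures
      .join(broadcast(tagR), col("vid_right") === col("r_vid"), "left")
      .drop("r_vid")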
 
      // Save the result to HDFS
       result.write
@@ -300,9 +353,9 @@ object video_dssm_sampler {
   }
 }
 
-
// Run
// Upload the code
+// ssh 192.168.141.208
 // git clone
 // mvn clean install
-//
+// /opt/apps/SPARK2/spark-2.4.8-hadoop3.2-1.0.8/bin/spark-class2 org.apache.spark.deploy.SparkSubmit   --class com.aliyun.odps.spark.examples.makedata_recsys.video_dssm_sampler   --master yarn --driver-memory 1G --executor-memory 2G --executor-cores 1 --num-executors 16   ./target/spark-examples-1.0.0-SNAPSHOT-shaded.jar    20241120 /dw/recommend/dssm_model/