
odps read Table

often committed 5 months ago
Parent
Current commit
bc5008461b
2 files changed, 134 insertions and 113 deletions
  1. pom.xml (+5 −2)
  2. src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala (+129 −111)

+ 5 - 2
pom.xml

@@ -169,8 +169,11 @@
             <artifactId>jedis</artifactId>
             <version>3.3.0</version>
         </dependency>
-
-
+        <dependency>
+            <groupId>com.jayway.jsonpath</groupId>
+            <artifactId>json-path</artifactId>
+            <version>2.8.0</version>
+        </dependency>
         <dependency>
             <groupId>org.projectlombok</groupId>
             <artifactId>lombok</artifactId>
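
The json-path dependency added above backs the JsonPath.read calls introduced in video_dssm_sampler.scala below. A minimal sketch of the API, with a hypothetical JSON payload:

    import com.jayway.jsonpath.JsonPath

    // JsonPath.read extracts a typed value from a JSON string by path
    val feature = """{"category1":"news","category2_1":"politics"}"""
    val category1: String = JsonPath.read[String](feature, "$.category1")  // "news"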

+ 129 - 111
src/main/scala/com/aliyun/odps/spark/examples/makedata_recsys/video_dssm_sampler.scala

@@ -1,11 +1,16 @@
 package com.aliyun.odps.spark.examples.makedata_recsys
-import org.apache.spark.sql.{DataFrame, SparkSession, Row}
+import com.aliyun.odps.TableSchema
+import com.aliyun.odps.data.Record
+import com.aliyun.odps.spark.examples.myUtils.{ParamUtils, env}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
+import com.jayway.jsonpath.JsonPath
+
 import scala.util.Random
 import scala.collection.mutable.ArrayBuffer
-import org.apache.log4j.{Logger, Level}
+import org.apache.log4j.{Level, Logger}
 
 object video_dssm_sampler {
   private val logger = Logger.getLogger(this.getClass)
@@ -48,46 +53,83 @@ object video_dssm_sampler {
     spark
   }
 
+  def funcPositive(record: Record, schema: TableSchema): Row = {
+    Row(
+      record.getString("dt"),
+      record.getString("hh"),
+      record.getString("vid_left"),
+      record.getString("vid_right"),
+      record.getString("extend"),
+      record.getString("view_24h"),
+      record.getString("total_return_uv"),
+      record.getString("ts_right"),
+      record.getString("apptype"),
+      record.getString("pagesource"),
+      record.getString("mid"))
+  }
+  def funcVids(record: Record, schema: TableSchema): String = {
+    record.getString("vid")
+  }
+  def funcTagFeatures(record: Record, schema: TableSchema): (String, String) = {
+    (record.getString("vid"), record.getString("feature"))
+  }
+  def funcStatFeatures(record: Record, schema: TableSchema): (String, String) = {
+    (record.getString("vid"), record.getString("feature"))
+  }
+  // Fetch category features
+  def funcCategoryFeatures(record: Record, schema: TableSchema): (String, String, String, String) = {
+    val feature = record.getString("feature")
+    val category1 = JsonPath.read[String](feature, "$.category1")
+    val category2 = JsonPath.read[String](feature, "$.category2_1")
+    (record.getString("vid"), category1, category2, feature)
+  }
+
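
A caveat with the transfer functions above: Record.getString returns null when the underlying ODPS column is NULL, so every field they emit can be null downstream. A null-tolerant variant, as a sketch:

    // Sketch: default NULL columns to an empty string instead of propagating null
    def funcVidsSafe(record: Record, schema: TableSchema): String =
      Option(record.getString("vid")).getOrElse("")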
   def generateNegativeSamples(spark: SparkSession, dt: String, outputPath: String): Unit = {
     val stats = ProcessingStats()
 
     try {
       logger.info(s"Starting negative sample generation for date: $dt")
+      val sc = spark.sparkContext
 
-      // 1. Fetch and cache positive sample pair data
+      // 2. Read ODPS table info
+      val odpsOps = env.getODPS(sc)
+      // 1. Fetch positive sample pair data
       val positivePairs = time({
-        val df = spark.sql(s"""
-          SELECT
-            dt,
-            hh,
-            vid_left,
-            vid_right,
-            extend,
-            view_24h,
-            total_return_uv,
-            ts_right,
-            apptype,
-            pagesource,
-            mid
-          FROM loghubods.alg_dssm_sample
-          WHERE dt = '$dt'
-        """).persist(StorageLevel.MEMORY_AND_DISK)
-
+        val schema = StructType(Array(
+          StructField("dt", StringType, true),
+          StructField("hh", StringType, true),
+          StructField("vid_left", StringType, true),
+          StructField("vid_right", StringType, true),
+          StructField("extend", StringType, true),
+          StructField("view_24h", StringType, true),
+          StructField("total_return_uv", StringType, true),
+          StructField("ts_right", StringType, true),
+          StructField("apptype", StringType, true),
+          StructField("pagesource", StringType, true),
+          StructField("mid", StringType, true)
+        ))
+
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "alg_dssm_sample",
+          partition = s"dt='$dt'",
+          transfer = funcPositive,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        ).persist(StorageLevel.MEMORY_AND_DISK)
+        val df = spark.createDataFrame(rdd, schema).persist(StorageLevel.MEMORY_AND_DISK)
         stats.positiveSamplesCount = df.count()
         df
       }, "Loading positive pairs")
 
       // 2. Get the list of all available vids
       val allVids = time({
-        spark.sql(s"""
-          SELECT vid
-          FROM loghubods.t_vid_tag_feature
-          WHERE dt = '$dt'
-        """)
-          .select("vid")
-          .collect()
-          .map(_.getString(0))
-          .toSet
+        odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_tag_feature",
+          partition = s"dt='$dt'",
+          transfer = funcVids,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        ).collect().toSet
       }, "Loading all vids")
 
       // 3. Define a UDF to generate negative samples
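
The UDF body falls outside this hunk, but since allVids is collected to the driver as a plain Set, the usual Spark pattern is to broadcast it once rather than capture it in every task closure. A sketch, assuming the set fits comfortably in memory:

    // Broadcast the candidate pool once; executors read it via allVidsBc.value
    val allVidsBc = spark.sparkContext.broadcast(allVids.toArray)
    // e.g. inside the sampling logic:
    //   val candidate = allVidsBc.value(Random.nextInt(allVidsBc.value.length))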
@@ -134,98 +176,74 @@ object video_dssm_sampler {
 
       // 6. Fetch left-side features
       // Fetch tag features
-      val tagFeatures = spark.sql(s"""
-        SELECT vid, feature as vid_left_tag_feature
-        FROM loghubods.t_vid_tag_feature
-        WHERE dt = '$dt'
-      """)
 
-      // Fetch stat features
-      val statFeatures = spark.sql(s"""
-        SELECT vid, feature as vid_left_stat_feature
-        FROM loghubods.t_vid_stat_feature
-        WHERE dt = '$dt'
-      """)
-
-      // Fetch category features
-      val categoryFeatures = spark.sql(s"""
-        SELECT
-          a.vid,
-          b.feature as vid_left_cate_l1_feature,
-          c.feature as vid_left_cate_l2_feature
-        FROM (
-          SELECT
-            vid,
-            get_json_object(feature,"$$.category1") as category1,
-            get_json_object(feature,"$$.category2_1") as category2
-          FROM loghubods.t_vid_tag_feature
-          WHERE dt = '$dt'
-        ) a
-        LEFT JOIN (
-          SELECT category1, feature
-          FROM t_vid_l1_cat_stat_feature
-          WHERE dt = '$dt'
-        ) b ON a.category1 = b.category1
-        LEFT JOIN (
-          SELECT category2, feature
-          FROM t_vid_l2_cat_stat_feature
-          WHERE dt = '$dt'
-        ) c ON a.category2 = c.category2
-      """)
+      val tagFeatures = {
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_tag_feature",
+          partition = s"dt='$dt'",
+          transfer = funcTagFeatures,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+        val schema = StructType(Array(
+          StructField("vid", StringType, true),
+          StructField("vid_left_tag_feature", StringType, true)
+        ))
+        spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
+      }
 
+
+      // Fetch stat features
+      val statFeatures = {
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_stat_feature",
+          partition = s"dt='$dt'",
+          transfer = funcStatFeatures,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+        val schema = StructType(Array(
+          StructField("vid", StringType, true),
+          StructField("vid_left_stat_feature", StringType, true)
+        ))
+        spark.createDataFrame(rdd.map(t => Row(t._1, t._2)), schema)
+      }
+
+
+      // Read and process category features
+      val categoryData = {
+        val rdd = odpsOps.readTable(
+          project = "loghubods",
+          table = "t_vid_tag_feature",
+          partition = s"dt='$dt'",
+          transfer = funcCategoryFeatures,
+          numPartition = CONFIG("shuffle.partitions").toInt
+        )
+
+        // Define the schema
+        val schema = StructType(Array(
+          StructField("vid", StringType, true),
+          StructField("category1", StringType, true),
+          StructField("category2", StringType, true),
+          StructField("feature", StringType, true)
+        ))
+
+        // Convert the RDD to a DataFrame
+        spark.createDataFrame(rdd.map(t => Row(t._1, t._2, t._3, t._4)), schema)
+      }
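
Unlike the get_json_object calls this replaces, which yield NULL for a missing key, JsonPath.read throws PathNotFoundException when the path is absent, so one malformed feature row can fail the whole task. A defensive wrapper, as a sketch:

    import com.jayway.jsonpath.{JsonPath, PathNotFoundException}

    // Sketch: mirror get_json_object's NULL-on-missing behaviour
    def readPathOrNull(json: String, path: String): String =
      try JsonPath.read[String](json, path)
      catch { case _: PathNotFoundException => null }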
       // Join all left-side features
       val vidLeftFeatures = allSamples
         .join(broadcast(tagFeatures), col("vid_left") === tagFeatures("vid"), "left")
         .join(broadcast(statFeatures), col("vid_left") === statFeatures("vid"), "left")
-        .join(broadcast(categoryFeatures), col("vid_left") === categoryFeatures("vid"), "left")
+        .join(broadcast(categoryData), col("vid_left") === categoryData("vid"), "left")
         .persist(StorageLevel.MEMORY_AND_DISK)
 
-      // 7. Fetch right-side features and produce the final result
-      // Fetch tag features
-      val rightTagFeatures = spark.sql(s"""
-        SELECT vid, feature as vid_right_tag_feature
-        FROM loghubods.t_vid_tag_feature
-        WHERE dt = '$dt'
-      """)
-
-      // Fetch stat features
-      val rightStatFeatures = spark.sql(s"""
-        SELECT vid, feature as vid_right_stat_feature
-        FROM loghubods.t_vid_stat_feature
-        WHERE dt = '$dt'
-      """)
-
-      // Fetch category features
-      val rightCategoryFeatures = spark.sql(s"""
-        SELECT
-          a.vid,
-          b.feature as vid_right_cate_l1_feature,
-          c.feature as vid_right_cate_l2_feature
-        FROM (
-          SELECT
-            vid,
-            get_json_object(feature,"$$.category1") as category1,
-            get_json_object(feature,"$$.category2_1") as category2
-          FROM loghubods.t_vid_tag_feature
-          WHERE dt = '$dt'
-        ) a
-        LEFT JOIN (
-          SELECT category1, feature
-          FROM t_vid_l1_cat_stat_feature
-          WHERE dt = '$dt'
-        ) b ON a.category1 = b.category1
-        LEFT JOIN (
-          SELECT category2, feature
-          FROM t_vid_l2_cat_stat_feature
-          WHERE dt = '$dt'
-        ) c ON a.category2 = c.category2
-      """)
 
       // Join all right-side features and produce the final result
       val result = vidLeftFeatures
-        .join(broadcast(rightTagFeatures), col("vid_right") === rightTagFeatures("vid"), "left")
-        .join(broadcast(rightStatFeatures), col("vid_right") === rightStatFeatures("vid"), "left")
-        .join(broadcast(rightCategoryFeatures), col("vid_right") === rightCategoryFeatures("vid"), "left")
+        .join(broadcast(tagFeatures), col("vid_right") === tagFeatures("vid"), "left")
+        .join(broadcast(statFeatures), col("vid_right") === statFeatures("vid"), "left")
+        .join(broadcast(categoryData), col("vid_right") === categoryData("vid"), "left")
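
Note that reusing tagFeatures, statFeatures, and categoryData for the right side carries their left-named columns (vid_left_tag_feature, etc.) into the result and duplicates the vid column several times. One way to keep the output schema unambiguous, sketched with hypothetical renames:

    val rightTagFeatures = tagFeatures
      .withColumnRenamed("vid_left_tag_feature", "vid_right_tag_feature")
      .withColumnRenamed("vid", "vid_r")  // avoid a duplicate "vid" after the join
    val result = vidLeftFeatures
      .join(broadcast(rightTagFeatures), col("vid_right") === col("vid_r"), "left")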
 
 
       // Save the result to HDFS
       result.write
@@ -239,7 +257,7 @@ object video_dssm_sampler {
       positivePairs.unpersist()
       allSamples.unpersist()
       vidLeftFeatures.unpersist()
-      rightTagFeatures.unpersist()
+
 
       // Log statistics
       stats.logStats()