Add sample features to a temporary table

xueyiming · 2 weeks ago · parent commit 7fc35e1f2d

+ 137 - 0
src/main/scala/com/aliyun/odps/spark/examples/makedata_ad/v20240718/makedata_ad_33_addFeatureToHive_20250708.scala

@@ -0,0 +1,137 @@
+package com.aliyun.odps.spark.examples.makedata_ad.v20240718
+
+import com.aliyun.odps.{Column, TableSchema}
+import com.aliyun.odps.data.Record
+import com.aliyun.odps.spark.examples.myUtils.{MyDateUtils, ParamUtils, env}
+import org.apache.spark.sql.SparkSession
+
+object makedata_ad_33_addFeatureToHive_20250708 {
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession
+      .builder()
+      .appName(this.getClass.getName)
+      .getOrCreate()
+    val sc = spark.sparkContext
+
+    val param = ParamUtils.parseArgs(args)
+    val project = param.getOrElse("project", "loghubods")
+    val tablePart = param.getOrElse("tablePart", "64").toInt
+    val beginStr = param.getOrElse("beginStr", "20250708")
+    val endStr = param.getOrElse("endStr", "20250708")
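+    // Assumption: ParamUtils.parseArgs turns key=value CLI args into a Map, so the
+    // job would be launched with args like: beginStr=20250701 endStr=20250707
+    // beginStr/endStr bound the dt partitions to process (yyyyMMdd, inclusive)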
+
+    val odpsOps = env.getODPS(sc)
+    val dateRange = MyDateUtils.getDateRange(beginStr, endStr)
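+    // One pass per day: each dt in [beginStr, endStr] maps to one dt=... partition
+    // in both source tables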
+    for (dt <- dateRange) {
+      val partition = s"dt=$dt"
+      // Read the sample table
+      val sampleRdd = odpsOps.readTable(
+        project = project,
+        table = "ad_easyrec_train_realtime_data_v3_sampled",
+        partition = partition,
+        transfer = func,
+        numPartition = tablePart)
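+      // func (defined below) turns each Record into a (mid, column map) pair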
+
+      // Read the feature table
+      val featureRdd = odpsOps.readTable(
+        project = project,
+        table = "ad_transformed_class_base_data_all",
+        partition = partition,
+        transfer = func,
+        numPartition = tablePart // the feature table is usually small
+      )
+
+      // Turn each RDD into key-value pairs keyed by mid, keeping the full tuple
+      val samplePairRdd = sampleRdd.map(row => (row._1, row)) // (mid, full sample tuple)
+      val featurePairRdd = featureRdd.map(row => (row._1, row)) // (mid, full feature tuple)
+
+      // Join the two RDDs on mid and merge the two column Maps
+      val recordRdd = samplePairRdd.join(featurePairRdd).map {
+        case (_, (sampleTuple, featureTuple)) =>
+          val sampleData = sampleTuple._2 // the sample row's column Map
+          val featureData = featureTuple._2 // the feature row's column Map
+
+          // Merge the two Maps; on duplicate keys the feature value wins
+          sampleData ++ featureData
+      }
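+      // For reference, Scala's Map ++ is right-biased on duplicate keys, e.g.
+      //   Map("a" -> "1", "b" -> "2") ++ Map("b" -> "9") == Map("a" -> "1", "b" -> "9")
+      // so feature columns overwrite sample columns of the same name.
+      // Write the merged rows to the temp table partition; by the parameter names,
+      // defaultCreate presumably creates the table if missing and overwrite replaces
+      // any existing data in the partition.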
+      odpsOps.saveToTable(project, "ad_easyrec_train_realtime_data_v3_sampled_temp",
+        partition, recordRdd, write, defaultCreate = true, overwrite = true)
+    }
+  }
+
+  def func(record: Record, schema: TableSchema): (String, Map[String, String]) = {
+    // 1. Get all columns of the table
+    val columns: Array[Column] = schema.getColumns.toArray(Array.empty[Column])
+
+    // 2. Find the index of the "mid" column
+    val midIndex = columns.indexWhere(_.getName == "mid")
+
+    // 3. Fail fast if the mid column does not exist
+    if (midIndex == -1) {
+      throw new IllegalArgumentException("Table has no 'mid' column, please check the column name")
+    }
+
+    // 4. Extract the mid value; String.valueOf maps a null to the string "null"
+    //    instead of throwing a NullPointerException
+    val mid = String.valueOf(record.get(midIndex))
+
+    // 5. Convert the Record into a Map[String, String] (the mid column is kept too)
+    val recordMap = columns.zipWithIndex
+      .map { case (column, index) =>
+        // Keep null values as null rather than converting them to empty strings
+        val value: String = record.get(index) match {
+          case null => null
+          case v => v.toString
+        }
+
+        column.getName -> value
+      }
+      .toMap
+
+    // 6. Return (mid, Map[String, String])
+    (mid, recordMap)
+  }
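+  // Note: because func keeps mid inside the returned Map, the merged Map still
+  // contains mid after the join and it is written back like any other column.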
+
+  def write(map: Map[String, String], record: Record, schema: TableSchema): Unit = {
+    for ((columnName, value) <- map) {
+      try {
+        // Look up the column's index in the table schema
+        val columnIndex = schema.getColumnIndex(columnName.toLowerCase)
+        // Get the column's type
+        val columnType = schema.getColumn(columnIndex).getTypeInfo
+        try {
+          // Skip nulls: the numeric conversions below would otherwise throw
+          if (value != null) {
+            columnType.getTypeName match {
+              case "STRING" =>
+                record.setString(columnIndex, value)
+              case "BIGINT" =>
+                record.setBigint(columnIndex, value.toLong)
+              case "DOUBLE" =>
+                record.setDouble(columnIndex, value.toDouble)
+              case "BOOLEAN" =>
+                record.setBoolean(columnIndex, value.toBoolean)
+              case other =>
+                throw new IllegalArgumentException(s"Unsupported column type: $other")
+            }
+          }
+        } catch {
+          case e: NumberFormatException =>
+            println(s"Error converting value $value to type ${columnType.getTypeName} for column $columnName: ${e.getMessage}")
+          case e: Exception =>
+            println(s"Unexpected error writing value $value to column $columnName: ${e.getMessage}")
+        }
+      } catch {
+        case e: IllegalArgumentException =>
+          println(e.getMessage)
+      }
+    }
+  }
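+  // Note: only STRING/BIGINT/DOUBLE/BOOLEAN are handled above; any other ODPS type
+  // (e.g. DECIMAL or DATETIME) falls into the IllegalArgumentException branch.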
+}