فهرست منبع

增加新特征字段

xueyiming 4 هفته پیش
والد
کامیت
1727e7d85f

+ 113 - 0
src/main/scala/com/aliyun/odps/spark/examples/makedata_ad/v20250813/makedata_ad_33_bucketData_add_Feature_20250813.scala

@@ -0,0 +1,113 @@
+package com.aliyun.odps.spark.examples.makedata_ad.v20250813
+
+import com.aliyun.odps.TableSchema
+import com.aliyun.odps.data.Record
+import com.aliyun.odps.spark.examples.myUtils.{MyDateUtils, ParamUtils, env}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SparkSession
+
+import scala.collection.JavaConverters._
/**
 * Enriches the sampled ad training table with advertiser attributes
 * (currently only `category_name`) by left-joining sample rows
 * (key: `adverid`) against the advertiser dimension table (key: `id`),
 * then writes the result day-by-day into the output table's `dt` partitions.
 *
 * Parameters (all optional, via ParamUtils):
 *   tablePart    - number of read partitions (default 64)
 *   beginStr     - first dt partition, yyyyMMdd (default 20250812)
 *   endStr       - last dt partition, yyyyMMdd (default 20250812)
 *   project      - ODPS project (default loghubods)
 *   inputTable   - sample fact table
 *   inputTable2  - advertiser dimension table
 *   outputTable  - destination table
 */
object makedata_ad_33_bucketData_add_Feature_20250813 {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName(this.getClass.getName)
      .getOrCreate()
    val sc = spark.sparkContext
    val odpsOps = env.getODPS(sc)

    // 1. Parse job parameters.
    val param = ParamUtils.parseArgs(args)
    val tablePart = param.getOrElse("tablePart", "64").toInt
    val beginStr = param.getOrElse("beginStr", "20250812")
    val endStr = param.getOrElse("endStr", "20250812")
    val project = param.getOrElse("project", "loghubods")
    val inputTable = param.getOrElse("inputTable", "ad_easyrec_train_realtime_data_v3_sampled")
    val inputTable2 = param.getOrElse("inputTable2", "loghubods.advertiser")
    // BUG FIX: the original looked up the key "inputTable2" here, so passing
    // inputTable2 on the command line silently redirected the job's OUTPUT.
    val outputTable = param.getOrElse("outputTable", "ad_easyrec_train_realtime_data_v3_sampled_temp")

    // 2. The advertiser dimension table is read without a date partition, so it
    //    is loop-invariant — build its keyed RDD once instead of once per day.
    val advertiserData = odpsOps.readTable(project = project,
      table = inputTable2,
      transfer = func,
      numPartition = tablePart)
    // NOTE(review): assumes "id" is never NULL in the advertiser table; a null
    // java.lang.Long would NPE when unboxed to Long — confirm upstream schema.
    val advertiserById: RDD[(Long, Record)] = advertiserData.map(record => (record.getBigint("id"), record))

    // 3. Produce one output partition per day in [beginStr, endStr].
    val dateRange = MyDateUtils.getDateRange(beginStr, endStr)
    for (dt <- dateRange) {
      val partition = s"dt=$dt"
      val odpsData = odpsOps.readTable(project = project,
        table = inputTable,
        partition = partition,
        transfer = func,
        numPartition = tablePart)

      val keyedSamples: RDD[(Long, Record)] = odpsData.map(record => (record.getBigint("adverid"), record))
      // Left outer join keeps every sample row even when no advertiser matches.
      val joinedRDD = keyedSamples.leftOuterJoin(advertiserById)
      odpsOps.saveToTable(project, outputTable, partition, joinedRDD, write, defaultCreate = true, overwrite = true)
    }
  }

  /**
   * Write callback for saveToTable: fills one output record from a joined pair.
   *
   * Columns listed in `rightTableFields` are copied from the (optional) right
   * advertiser record — NULL when the join found no match; every other column
   * is copied from the left sample record. Only BIGINT/STRING/DOUBLE/BOOLEAN
   * columns are supported; any other type fails fast.
   *
   * @param data   (join key, (left sample record, optional advertiser record))
   * @param record output record to populate (mutated in place)
   * @param schema output table schema driving the column-by-column copy
   */
  def write(data: (Long, (Record, Option[Record])),
            record: Record,
            schema: TableSchema): Unit = {

    // Destructure the joined pair; the key itself is not written.
    val (_, (leftRecord, rightRecordOpt)) = data

    // Columns sourced from the right (advertiser) table; all others come
    // from the left (sample) table. Extend this set to add more features.
    val rightTableFields = Set("category_name")

    // Copy a single column from `source` into the output record, dispatching
    // on the schema-declared type. Shared by the left- and right-table paths
    // (the original duplicated this match twice).
    def copyField(source: Record, fieldName: String): Unit = {
      val colIndex = schema.getColumnIndex(fieldName)
      val colType = schema.getColumn(fieldName).getTypeInfo
      colType.getTypeName match {
        case "BIGINT"  => record.setBigint(colIndex, source.getBigint(fieldName))
        case "STRING"  => record.setString(colIndex, source.getString(fieldName))
        case "DOUBLE"  => record.setDouble(colIndex, source.getDouble(fieldName))
        case "BOOLEAN" => record.setBoolean(colIndex, source.getBoolean(fieldName))
        case _ => throw new UnsupportedOperationException(s"Unsupported type: ${colType.getTypeName}")
      }
    }

    // Walk every output column and copy it from the appropriate side.
    schema.getColumns.asScala.map(_.getName).foreach { fieldName =>
      if (rightTableFields.contains(fieldName)) {
        rightRecordOpt match {
          case Some(rightRecord) => copyField(rightRecord, fieldName)
          // Unmatched join: the advertiser feature is explicitly NULL.
          case None => record.set(fieldName, null)
        }
      } else {
        copyField(leftRecord, fieldName)
      }
    }
  }

  /** Identity transfer function: pass ODPS records through unchanged. */
  def func(record: Record, schema: TableSchema): Record = {
    record
  }

}