Browse Source

修改逻辑

xueyiming 2 weeks ago
parent
commit
9805c8e784

+ 34 - 51
src/main/scala/com/aliyun/odps/spark/examples/makedata_ad/v20240718/makedata_ad_33_diffFeature_20250708.scala

@@ -88,72 +88,55 @@ object makedata_ad_33_diffFeature_20250708 {
     val columns: Array[Column] = schema.getColumns.toArray(Array.empty[Column])
 
     // 2. Scan the columns to find the index of the "pqtid" field
-    var pqtIdIndex = -1
+    var pdtIdIndex = -1
     for (i <- columns.indices) {
       if (columns(i).getName == "pqtid") {
-        pqtIdIndex = i
+        pdtIdIndex = i
       }
     }
 
-    // 3. 检查 mid 字段是否存在
-    if (pqtIdIndex == -1) {
-      throw new IllegalArgumentException("表中不存在 'mid' 字段,请检查字段名")
+    // 4. Read the "pqtid" field (non-string column types are converted to a string)
+    val midColumn = columns(pdtIdIndex)
+    val midType = midColumn.getTypeInfo.getTypeName // 获取 mid 字段类型
+    val midValue: Any = midType match {
+      case "STRING" => record.getString(pdtIdIndex)
+      case "BIGINT" => record.getBigint(pdtIdIndex)  // 长整型
+      case "DOUBLE" => record.getDouble(pdtIdIndex)  // 浮点型
+      case "BOOLEAN" => record.getBoolean(pdtIdIndex) // 布尔型
     }
+    val mid = Option(midValue).map(_.toString).getOrElse("") // 转换为字符串,null 转为空字符串
 
-    val pqtId = Option(record.get(pqtIdIndex))
-      .map(_.toString) // 非 null 值转为字符串
-      .getOrElse("") // null 值返回空字符串(或其他默认值)
-
-    // 5. 将 Record 转换为 Map[String, String](跳过 mid 字段)
+    // 5. 处理所有字段(非字符串类型转为字符串)
     val recordMap = columns.zipWithIndex
       .map { case (column, index) =>
-        // 获取字段值,保留 null(不转换为空字符串)
-        val value: String = record.get(index) match {
-          case null => null // 保留 null 值
-          case value => value.toString // 非 null 值转换为字符串
+        val columnName = column.getName
+        val columnType = column.getTypeInfo.getTypeName // 获取字段类型
+
+        // 根据字段类型获取值并转换为字符串
+        val value: String = columnType match {
+          case "STRING" =>
+            val str = record.getString(index)
+            if (str == null) null else str // 字符串类型直接保留(null 保持 null)
+
+          case "BIGINT" =>
+            val num = record.getBigint(index)
+            if (num == null) null else num.toString // 长整型转字符串
+
+          case "DOUBLE" =>
+            val num = record.getDouble(index)
+            if (num == null) null else num.toString // 浮点型转字符串
+
+          case "BOOLEAN" =>
+            val bool = record.getBoolean(index)
+            if (bool == null) null else bool.toString // 布尔型转字符串("true"/"false")
         }
 
-        column.getName -> value
+        columnName -> value
       }
       .toMap
 
     // 6. 返回 (mid, Map[String, String])
-    (pqtId, recordMap)
-  }
-
-  def write(map: Map[String, String], record: Record, schema: TableSchema): Unit = {
-    for ((columnName, value) <- map) {
-      try {
-        // 查找列名在表结构中的索引
-        val columnIndex = schema.getColumnIndex(columnName.toLowerCase)
-        // 获取列的类型
-        val columnType = schema.getColumn(columnIndex).getTypeInfo
-        try {
-          columnType.getTypeName match {
-            case "STRING" =>
-              record.setString(columnIndex, value)
-            case "BIGINT" =>
-              record.setBigint(columnIndex, value.toLong)
-            case "DOUBLE" =>
-              record.setDouble(columnIndex, value.toDouble)
-            case "BOOLEAN" =>
-              record.setBoolean(columnIndex, value.toBoolean)
-            case other =>
-              throw new IllegalArgumentException(s"Unsupported column type: $other")
-          }
-        } catch {
-          case e: NumberFormatException =>
-            println(s"Error converting value $value to type ${columnType.getTypeName} for column $columnName: ${e.getMessage}")
-          case e: Exception =>
-            println(s"Unexpected error writing value $value to column $columnName: ${e.getMessage}")
-        }
-      } catch {
-        case e: IllegalArgumentException => {
-          println(e.getMessage)
-        }
-      }
-    }
+    (mid, recordMap)
   }
 
-
 }