Add logKey

jch, 4 months ago
Commit c9c461b1a6

+ 10 - 6
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/pred_recsys_61_xgb_nor_hdfsfile_20241209.scala

@@ -47,9 +47,12 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
      .filter(r => r.nonEmpty || !featureFilter.contains(r))
    println("features.size=" + features.length)

-    val fields = Array(
+    var fields = Array(
       DataTypes.createStructField("label", DataTypes.DoubleType, true)
       DataTypes.createStructField("label", DataTypes.DoubleType, true)
     ) ++ features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
     ) ++ features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
+    fields = fields ++ Array(
+      DataTypes.createStructField("logKey", DataTypes.StringType, true)
+    )

    val schema = DataTypes.createStructType(fields)
    val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
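
This first hunk widens the DataFrame schema: fields becomes a var so a logKey string column can be appended after the label and the feature columns. A minimal sketch of an equivalent construction without the var reassignment (same Spark SQL types API; features is the Array[String] built above):

    import org.apache.spark.sql.types.{DataTypes, StructField}

    // Equivalent immutable construction: label first, one DoubleType column per
    // feature, then the trailing logKey StringType column.
    val fields: Array[StructField] =
      Array(DataTypes.createStructField("label", DataTypes.DoubleType, true)) ++
        features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true)) ++
        Array(DataTypes.createStructField("logKey", DataTypes.StringType, true))
    val schema = DataTypes.createStructType(fields)
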
@@ -63,13 +66,13 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
    )

    val testDataSet = spark.createDataFrame(testData, schema)
-    val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label")
+    val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
    val predictions = model.transform(testDataSetTrans)
    val clipPrediction = getClipData(spark, predictions).persist()

-    val saveData = clipPrediction.select("label", "prediction", "clipPrediction").rdd
+    val saveData = clipPrediction.select("label", "prediction", "clipPrediction", "logKey").rdd
      .map(r => {
-        (r.get(0), r.get(1), r.get(2)).productIterator.mkString("\t")
+        (r.get(0), r.get(1), r.get(2), r.get(3)).productIterator.mkString("\t")
      })
    val hdfsPath = savePath
    if (hdfsPath.nonEmpty && hdfsPath.startsWith("/dw/recommend/model/")) {
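
The second hunk threads logKey through prediction and into the saved file: it is selected alongside the assembled features vector (VectorAssembler appends its output column and passes the other columns through untouched), and each output record gains a fourth tab-separated field. A sketch of the same save step using a Seq instead of the tuple/productIterator round-trip (clipPrediction as in the diff):

    // Each saved line becomes: label \t prediction \t clipPrediction \t logKey
    val saveData = clipPrediction
      .select("label", "prediction", "clipPrediction", "logKey")
      .rdd
      .map(r => Seq(r.get(0), r.get(1), r.get(2), r.get(3)).mkString("\t"))
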
@@ -104,7 +107,7 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
  def createData(data: RDD[String], features: Array[String]): RDD[Row] = {
    data.map(r => {
      val line: Array[String] = StringUtils.split(r, '\t')
-      // val logKey = line(0)
+      val logKey = line(0)
      val label: Double = NumberUtils.toDouble(line(1))
      val map: util.Map[String, Double] = new util.HashMap[String, Double]
      for (i <- 2 until line.length) {
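
In createData, the previously commented-out read of line(0) is restored, which pins down the input layout the parser assumes: the log key is the first tab-separated field, followed by the label and then name:value feature pairs. A hypothetical standalone parse of one record under that assumption (the exact commons-lang variant the real file imports is a guess; the record value is illustrative, not real data):

    import org.apache.commons.lang3.StringUtils
    import org.apache.commons.lang3.math.NumberUtils

    val record = "key-123\t1.0\tf0:0.5\tf1:0.25"   // illustrative line
    val line   = StringUtils.split(record, '\t')
    val logKey = line(0)                            // "key-123"
    val label  = NumberUtils.toDouble(line(1))      // 1.0
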
@@ -112,11 +115,12 @@ object pred_recsys_61_xgb_nor_hdfsfile_20241209 {
        map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
      }

-      val v: Array[Any] = new Array[Any](features.length + 1)
+      val v: Array[Any] = new Array[Any](features.length + 2)
      v(0) = label
      for (i <- 0 until features.length) {
        v(i + 1) = map.getOrDefault(features(i), 0.0d)
      }
+      v(features.length + 1) = logKey
      Row(v: _*)
    })
  }
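
The last hunk grows the row buffer by one slot so its layout matches the extended schema: label in slot 0, the feature values in slots 1 through features.length, and logKey in the trailing slot. A small sketch of the resulting Row for an assumed two-feature case (all values illustrative):

    import org.apache.spark.sql.Row

    val features = Array("f0", "f1")            // assumed feature names
    val v: Array[Any] = new Array[Any](features.length + 2)
    v(0) = 1.0                                  // label
    v(1) = 0.5                                  // f0
    v(2) = 0.25                                 // f1
    v(features.length + 1) = "key-123"          // hypothetical log key, last slot
    Row(v: _*)                                  // Row(1.0, 0.5, 0.25, "key-123")
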