Переглянути джерело

feat:修改模型评估脚本

zhaohaipeng 1 тиждень тому
батько
коміт
f3ca85c488

+ 1 - 1
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_ros_binary_weight_xgb_train.scala → recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/ros/recsys_01_ros_binary_weight_xgb_train.scala

@@ -1,4 +1,4 @@
-package com.tzld.piaoquan.recommend.model
+package com.tzld.piaoquan.recommend.model.ros
 
 import com.alibaba.fastjson.JSON
 import com.tzld.piaoquan.recommend.utils.{MyHdfsUtils, ParamUtils}

+ 1 - 1
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_ros_multi_class_xgb_train.scala → recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/ros/recsys_01_ros_multi_class_xgb_train.scala

@@ -1,4 +1,4 @@
-package com.tzld.piaoquan.recommend.model
+package com.tzld.piaoquan.recommend.model.ros
 
 import com.alibaba.fastjson.JSON
 import com.tzld.piaoquan.recommend.model.produce.util.RosUtil

+ 1 - 1
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_ros_reg_weight_xgb_train.scala → recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/ros/recsys_01_ros_reg_weight_xgb_train.scala

@@ -1,4 +1,4 @@
-package com.tzld.piaoquan.recommend.model
+package com.tzld.piaoquan.recommend.model.ros
 
 import com.alibaba.fastjson.JSON
 import com.tzld.piaoquan.recommend.model.produce.util.RosUtil

+ 1 - 1
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_ros_reg_xgb_train.scala → recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/ros/recsys_01_ros_reg_xgb_train.scala

@@ -1,4 +1,4 @@
-package com.tzld.piaoquan.recommend.model
+package com.tzld.piaoquan.recommend.model.ros
 
 import com.alibaba.fastjson.JSON
 import com.tzld.piaoquan.recommend.model.produce.util.RosUtil

+ 1 - 1
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_02_ros_model_predict.scala → recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/ros/recsys_02_ros_model_predict.scala

@@ -1,4 +1,4 @@
-package com.tzld.piaoquan.recommend.model
+package com.tzld.piaoquan.recommend.model.ros
 
 import com.tzld.piaoquan.recommend.utils.{FileUtils, MyHdfsUtils, ParamUtils}
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel

+ 1 - 1
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/recsys_01_str_xgb_train.scala → recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/str/recsys_01_str_xgb_train.scala

@@ -1,4 +1,4 @@
-package com.tzld.piaoquan.recommend.model
+package com.tzld.piaoquan.recommend.model.str
 
 import com.alibaba.fastjson.JSON
 import com.tzld.piaoquan.recommend.utils.{MyHdfsUtils, ParamUtils}

+ 129 - 0
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/str/recsys_02_str_xgb_predict.scala

@@ -0,0 +1,129 @@
+package com.tzld.piaoquan.recommend.model.str
+
+import com.alibaba.fastjson.JSON
+import com.tzld.piaoquan.recommend.utils.{FileUtils, MyHdfsUtils, ParamUtils}
+import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostClassifier}
+import org.apache.commons.lang.math.NumberUtils
+import org.apache.commons.lang3.StringUtils
+import org.apache.hadoop.io.compress.GzipCodec
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.DataTypes
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
+
+import java.time.LocalDateTime
+import java.time.format.DateTimeFormatter
+import java.util
+import scala.io.Source
+
+object recsys_02_str_xgb_predict {
+  def main(args: Array[String]): Unit = {
+
+    val dt = DateTimeFormatter.ofPattern("yyyyMMddHHmm").format(LocalDateTime.now())
+
+    val spark = SparkSession
+      .builder()
+      .appName(this.getClass.getName + " : " + dt)
+      .getOrCreate()
+    val sc = spark.sparkContext
+
+    val param = ParamUtils.parseArgs(args)
+    val featureFile = param.getOrElse("featureFile", "20240703_ad_feature_name.txt")
+    val testPath = param.getOrElse("testPath", "/dw/recommend/model/33_ad_train_data_v4/20240725")
+    val savePath = param.getOrElse("savePath", "/dw/recommend/model/34_ad_predict_data/")
+    val featureFilter = param.getOrElse("featureFilter", "XXXXXX").split(",").filter(_.nonEmpty).toList
+    val repartition = param.getOrElse("repartition", "20").toInt
+    val modelPath = param.getOrElse("modelPath", "/dw/recommend/model/45_recommend_model/")
+    val modelFile = param.getOrElse("modelFile", "model.tar.gz")
+
+    val loader = getClass.getClassLoader
+    val resourceUrl = loader.getResource(featureFile)
+    val content = FileUtils.readFile(resourceUrl)
+    println(content)
+
+    val features = content.split("\n")
+      .map(r => r.replace(" ", "").replaceAll("\n", ""))
+      .filter(r => r.nonEmpty || !featureFilter.contains(r))
+    println("features.size=" + features.length)
+
+
+    var fields = Array(
+      DataTypes.createStructField("label", DataTypes.IntegerType, true)
+    ) ++ features.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
+    fields = fields ++ Array(
+      DataTypes.createStructField("logKey", DataTypes.StringType, true)
+    )
+    val schema = DataTypes.createStructType(fields)
+    val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
+    val testData = createData4Ad(sc.textFile(testPath), features)
+
+    // 加载模型
+    val model = XGBoostClassificationModel.load(modelPath)
+    model.setMissing(0.0f).setFeaturesCol("features")
+
+    val testDataSet = spark.createDataFrame(testData, schema)
+    val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features", "label", "logKey")
+    val predictions = model.transform(testDataSetTrans)
+
+    println("zhangbo:columns:" + predictions.columns.mkString(","))
+
+    val saveData = predictions.select("label", "rawPrediction", "probability", "logKey").rdd
+      .map(r => {
+        (r.get(0), r.get(1), r.get(2), r.get(3)).productIterator.mkString("\t")
+      })
+    val hdfsPath = savePath
+    if (hdfsPath.nonEmpty && hdfsPath.startsWith("/dw/recommend/model/")) {
+      println("删除路径并开始数据写入:" + hdfsPath)
+      MyHdfsUtils.delete_hdfs_path(hdfsPath)
+      saveData.repartition(repartition).saveAsTextFile(hdfsPath, classOf[GzipCodec])
+    } else {
+      println("路径不合法,无法写入:" + hdfsPath)
+    }
+
+
+    val evaluator = new BinaryClassificationEvaluator()
+      .setLabelCol("label")
+      .setRawPredictionCol("probability")
+      .setMetricName("areaUnderROC")
+    val auc = evaluator.evaluate(predictions.select("label", "probability"))
+    println("zhangbo:auc:" + auc)
+
+    // 统计分cid的分数
+    sc.textFile(hdfsPath).map(r => {
+      val rList = r.split("\t")
+      val vid = JSON.parseObject(rList(3)).getString("vid")
+      val score = rList(2).replace("[", "").replace("]", "")
+        .split(",")(1).toDouble
+      val label = rList(0).toDouble
+      (vid, (1, label, score))
+    }).reduceByKey {
+      case (a, b) => (a._1 + b._1, a._2 + b._2, a._3 + b._3)
+    }.map {
+      case (vid, (all, zheng, scores)) =>
+        (vid, all, zheng, scores, zheng / all, scores / all)
+    }.collect().sortBy(_._1).map(_.productIterator.mkString("\t")).foreach(println)
+
+  }
+
+
+  def createData4Ad(data: RDD[String], features: Array[String]): RDD[Row] = {
+    data.map(r => {
+      val line: Array[String] = StringUtils.split(r, '\t')
+      val label: Int = NumberUtils.toInt(line(0))
+      val map: util.Map[String, Double] = new util.HashMap[String, Double]
+      for (i <- 1 until line.length - 1) {
+        val fv: Array[String] = StringUtils.split(line(i), ':')
+        map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
+      }
+
+      val v: Array[Any] = new Array[Any](features.length + 2)
+      v(0) = label
+      for (i <- 0 until features.length) {
+        v(i + 1) = map.getOrDefault(features(i), 0.0d)
+      }
+      v(features.length + 1) = line(line.length - 1)
+      Row(v: _*)
+    })
+  }
+}