commit 02ea98f31c

+ 71 - 0
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/stat_qq.scala

@@ -0,0 +1,71 @@
+package com.tzld.piaoquan.recommend.model
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.{DataFrame, SparkSession}
+
+
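+// Quantile bucket statistics for model predictions: split scored samples into
+// equal-frequency buckets and compare the mean predicted score with the
+// observed label rate in each bucket.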
+object stat_qq {
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession
+      .builder()
+      .appName(this.getClass.getName)
+      .getOrCreate()
+    val sc = spark.sparkContext
+
+    // param
+    val param = ParamUtils.parseArgs(args)
+    val predictPath = param.getOrElse("predictPath", "/dw/recommend/model/general_model/ad_post_conver/eval")
+    val bucketNum = param.getOrElse("bucketNum", "60").toInt
+    val savePath = param.getOrElse("savePath", "/dw/recommend/model/general_model/ad_post_conver/qq")
+
+    // data
+    val predictDF = loadData(spark, sc.textFile(predictPath))
+
+    // process: assign each row to one of `bucketNum` equal-frequency buckets
+    // ordered by predicted score, then compute per-bucket score range,
+    // mean prediction, observed label rate and row count
+    import spark.implicits._
+    val bucketDF = predictDF.withColumn("bucketId", ntile(bucketNum).over(Window.orderBy("predict")))
+      .select($"predict", $"label", $"bucketId")
+      .groupBy("bucketId")
+      .agg(
+        min("predict").as("min"),
+        max("predict").as("max"),
+        round(avg("predict"), 6).as("predict"),
+        round(avg("label"), 6).as("real"),
+        count("label").as("cnt"))
+      .orderBy("bucketId")
+
+    // save
+    val hdfsPath = savePath
+    if (hdfsPath.nonEmpty && hdfsPath.startsWith("/dw/recommend/model/")) {
+      println("删除路径并开始数据写入:" + hdfsPath)
+      MyHdfsUtils.delete_hdfs_path(hdfsPath)
+      bucketDF.write.format("csv").option("header", "true").save(hdfsPath)
+    } else {
+      println("路径不合法,无法写入:" + hdfsPath)
+    }
+  }
+
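+  // Extracts the positive-class score from a "[negScore,posScore]" string,
+  // formatted to six decimal places; returns "-1" when the input cannot be parsed.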
+  private def parseScore(data: String): String = {
+    if (data.nonEmpty) {
+      val pair = data.replace("[", "").replace("]", "").split(",")
+      if (pair.length > 1) {
+        return pair(1).toDouble.formatted("%.6f")
+      }
+    }
+    "-1"
+  }
+
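+  // Parses tab-separated prediction lines (label, logit, "[neg,pos]" score pair, mid)
+  // and keeps only rows whose score parsed successfully (the predict > -0.1 filter).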
+  def loadData(spark: SparkSession, data: RDD[String]): DataFrame = {
+    import spark.implicits._
+    data
+      .map(r => {
+        val cells = r.split("\t")
+        val label = cells(0)
+        val logit = cells(1)
+        val score = parseScore(cells(2)).toDouble
+        val mid = cells(3)
+        (score, label, mid)
+      })
+      .filter(raw => {
+        raw._1 > -0.1
+      })
+      .toDF("predict", "label", "logKey")
+  }
+}
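
For reference, a minimal local sketch of how these pieces fit together (not part of the commit; the demo object name, sample rows, and local master are assumptions). It feeds a few tab-separated lines in the layout loadData expects (label, logit, "[negScore,posScore]" pair, mid) through the same ntile bucketing, with a tiny bucket count so the output stays readable:

package com.tzld.piaoquan.recommend.model

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Hypothetical demo object; only illustrates the input format and the bucketing step.
object stat_qq_demo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("stat_qq_demo").master("local[2]").getOrCreate()
    val sc = spark.sparkContext

    // Sample lines in the layout stat_qq.loadData parses: label \t logit \t [neg,pos] \t mid
    val lines = sc.parallelize(Seq(
      "1\t0.83\t[0.31,0.69]\tmid_001",
      "0\t-1.2\t[0.77,0.23]\tmid_002",
      "0\t-0.4\t[0.58,0.42]\tmid_003",
      "1\t0.10\t[0.45,0.55]\tmid_004"
    ))

    val predictDF = stat_qq.loadData(spark, lines)

    // Same bucketing as the job, with bucketNum = 2 for readability.
    predictDF
      .withColumn("bucketId", ntile(2).over(Window.orderBy("predict")))
      .groupBy("bucketId")
      .agg(round(avg("predict"), 6).as("predict"), round(avg("label"), 6).as("real"), count("label").as("cnt"))
      .orderBy("bucketId")
      .show(false)

    spark.stop()
  }
}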