
rov offline auc

jch, 3 months ago
commit 16d1b7b5a8

+ 94 - 0
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/rov_offline_ab_auc.scala

@@ -0,0 +1,94 @@
+package com.tzld.piaoquan.recommend.model
+
+import com.alibaba.fastjson.JSON
+import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SparkSession}
+
+object rov_offline_ab_auc {
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession
+      .builder()
+      .appName(this.getClass.getName)
+      .getOrCreate()
+    val sc = spark.sparkContext
+
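+    // Key=value command-line args:
+    //   testPath                 - tab-separated predict results (logKey \t label \t scoresMap \t features...)
+    //   whatApps                 - apptype whitelist (see validApp)
+    //   baseAbCodes / expAbCodes - abcode groups to compare
+    //   baseScore / expScore     - score keys read from scoresMap for each group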
+    val param = ParamUtils.parseArgs(args)
+    val testPath = param.getOrElse("testPath", "")
+    val whatApps = param.getOrElse("whatApps", "0,3,4,21,17").split(",").toSet
+    val baseAbCodes = param.getOrElse("baseAbCodes", "ab3,ab4,ab8,ab9").split(",").toSet
+    val expAbCodes = param.getOrElse("expAbCodes", "ab0,ab1,ab2,ab5,ab6,ab7").split(",").toSet
+    val baseScore = param.getOrElse("baseScore", "fmRov")
+    val expScore = param.getOrElse("expScore", "fmRov")
+
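+    // Load the shared test set once, then build one scored DataFrame per ab group.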
+    val testData = loadData(whatApps, sc.textFile(testPath))
+    val baseData = getSubData(spark, baseAbCodes, baseScore, testData)
+    val expData = getSubData(spark, expAbCodes, expScore, testData)
+
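+    // Offline AUC per group; the Double "score" column serves as the raw prediction.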
+    val evaluator = new BinaryClassificationEvaluator()
+      .setLabelCol("label")
+      .setRawPredictionCol("score")
+      .setMetricName("areaUnderROC")
+    val baseCnt = baseData.count()
+    val expCnt = expData.count()
+    val baseAuc = evaluator.evaluate(baseData.select("label", "score"))
+    val expAuc = evaluator.evaluate(expData.select("label", "score"))
+    printf("base count: %d, auc: %.6f\n", baseCnt, baseAuc)
+    printf("exp count: %d, auc: %.6f\n", expCnt, expAuc)
+    println("---------------------------------\n")
+    println("---------------------------------\n")
+  }
+
+
+  def loadData(whatApps: Set[String], data: RDD[String]): RDD[(String, Double, String)] = {
+    data
+      .map(r => {
+        // logKey + "\t" + label + "\t" + scoresMap + "\t" + featuresBucket.mkString("\t")
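+        // e.g. (illustrative values): "0,detail,xx,xx,xx,ab3,mid,vid,1,1718000000\t1.0\t{\"fmRov\":0.12}\t..."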
+        val rList = r.split("\t")
+        val logKey = rList(0)
+        val label = rList(1).toDouble
+        val scoresMap = rList(2)
+        (logKey, label, scoresMap)
+      })
+      .filter(raw => {
+        validApp(raw._1, whatApps)
+      })
+  }
+
+  private def validApp(logKey: String, whatApps: Set[String]): Boolean = {
+    // logKey fields: apptype, page, pagesource, recommendpagetype, flowpool, abcode, mid, vid, level, ts
+    val apptype = logKey.split(",")(0)
+    whatApps.contains(apptype)
+  }
+
+  // Read `key` from the scoresMap JSON; the -2 default marks a missing score
+  // so getSubData can drop such rows via its `score > -1` filter.
+  private def parseScore(data: String, key: String, default: String = "-2"): Double = {
+    JSON.parseObject(data).getOrDefault(key, default).toString.toDouble
+  }
+
+  private def getSubData(spark: SparkSession, abCodes: Set[String], whatScore: String, data: RDD[(String, Double, String)]): DataFrame = {
+    import spark.implicits._
+    data
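+      // Keep rows whose abcode belongs to this group and whose score key is present
+      // (parseScore returns -2 when the key is missing from scoresMap).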
+      .filter(raw => {
+        var flag = false
+        val cells = raw._1.split(",")
+        if (abCodes.contains(cells(5))) {
+          val score = parseScore(raw._3, whatScore)
+          flag = score > -1
+        }
+        flag
+      })
+      .map(raw => {
+        (raw._1, raw._2, parseScore(raw._3, whatScore))
+      })
+      .toDF("logKey", "label", "score")
+  }
+}
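
As a quick local sanity check of the AUC computation used above, here is a minimal self-contained sketch. The object name AucToyExample, the local[*] master, and the toy (label, score) rows are illustrative and not part of this commit:

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.sql.SparkSession

object AucToyExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("AucToyExample")
      .master("local[*]") // illustrative: run locally
      .getOrCreate()
    import spark.implicits._

    // Toy data: a good ranker puts higher scores on positive labels.
    val df = Seq(
      (1.0, 0.9), (0.0, 0.2), (1.0, 0.7),
      (0.0, 0.4), (1.0, 0.6), (0.0, 0.8)
    ).toDF("label", "score")

    // Same evaluator configuration as rov_offline_ab_auc: a Double score
    // column is accepted as the raw prediction column.
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      .setRawPredictionCol("score")
      .setMetricName("areaUnderROC")

    printf("toy AUC: %.6f\n", evaluator.evaluate(df))
    spark.stop()
  }
}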