scala train

zhangbo 8 months ago
parent
commit
e1be781b02

+ 90 - 0
recommend-model-produce/src/main/scala/model/train_01_xgb_ad_20240808.scala

@@ -0,0 +1,90 @@
+package model
+
+import org.apache.commons.lang3.StringUtils
+import org.apache.commons.lang3.math.NumberUtils
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+import org.apache.spark.sql.{Dataset, Row, RowFactory, SparkSession}
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
+import org.apache.spark.rdd.RDD
+
+import java.util
+
+object train_01_xgb_ad_20240808 {
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession
+      .builder()
+      .appName(this.getClass.getName)
+      .getOrCreate()
+    val sc = spark.sparkContext
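+    // 12-hour and 7-day ad counter features (CTR/CVR/conversion/click statistics) expected in the input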
+    val features = Array("cpa", "b2_12h_ctr", "b2_12h_ctcvr", "b2_12h_cvr", "b2_12h_conver", "b2_12h_click", "b2_12h_conver*log(view)", "b2_12h_conver*ctcvr", "b2_7d_ctr", "b2_7d_ctcvr", "b2_7d_cvr", "b2_7d_conver", "b2_7d_click", "b2_7d_conver*log(view)", "b2_7d_conver*ctcvr")
+
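+    // each input line is "label<TAB>name:value<TAB>name:value ..."; createData (below) maps it onto the schema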
+    val trainData = createData(
+      sc.textFile("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz"),
+      features
+    )
+    println("train data size:" + trainData.count())
+
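+    // build an explicit schema: an integer label column followed by one nullable Double column per feature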
+    val fields = new util.ArrayList[StructField]
+    fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true))
+    for (f <- features) {
+      fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true))
+    }
+    val schema = DataTypes.createStructType(fields)
+    val trainDataSet = spark.createDataFrame(trainData, schema)
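+    // collect the per-feature columns into the single "features" vector column XGBoost consumes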
+    val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
+    val xgbInput = vectorAssembler.transform(trainDataSet).select("features","label")
+    val xgbClassifier = new XGBoostClassifier()
+      .setEta(0.01f)
+      .setMissing(0.0f)
+      .setMaxDepth(5)
+      .setNumRound(100)
+      .setObjective("binary:logistic")
+      .setEvalMetric("auc")
+      .setFeaturesCol("features")
+      .setLabelCol("label")
+      .setNthread(1)
+      .setNumWorkers(1)
+    val model = xgbClassifier.fit(xgbInput)
+
+
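+    // score a held-out part file with the trained model, reusing the same schema and assembler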
+    val testData = createData(
+      sc.textFile("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz"),
+      features
+    )
+    val testDataSet = spark.createDataFrame(testData, schema)
+    val testDataSetTrans = vectorAssembler.transform(testDataSet).select("features","label")
+    val predictions = model.transform(testDataSetTrans)
+
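+    // dump label, prediction, feature vector and raw scores as tab-separated text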
+    val saveData = predictions.select("label", "prediction", "features", "rawPrediction", "probability").rdd
+      .map(r => r.toSeq.mkString("\t"))
+    saveData.repartition(1).saveAsTextFile("/dw/recommend/model/checkpoint_xgbtest")
+  }
+
+
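+  /**
+   * Parses raw "label<TAB>name:value<TAB>..." lines into Rows matching the training schema:
+   * the integer label first, then one Double per requested feature, defaulting to 0.0 when
+   * a feature is missing from the line.
+   */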
+  def createData(data: RDD[String], features: Array[String]): RDD[Row] = {
+    data.map(r => {
+      val line: Array[String] = StringUtils.split(r, '\t')
+      val label: Int = NumberUtils.toInt(line(0))
+      val map: util.Map[String, Double] = new util.HashMap[String, Double]
+      for (i <- 1 until line.length) {
+        val fv: Array[String] = StringUtils.split(line(i), ':')
+        map.put(fv(0), NumberUtils.toDouble(fv(1), 0.0))
+      }
+
+      // box the primitives explicitly: Row fields must be AnyRef, with the label in slot 0
+      val v: Array[AnyRef] = new Array[AnyRef](features.length + 1)
+      v(0) = Integer.valueOf(label)
+      for (i <- features.indices) {
+        v(i + 1) = java.lang.Double.valueOf(map.getOrDefault(features(i), 0.0d))
+      }
+      RowFactory.create(v: _*)
+    })
+  }
+}