Explorar el Código

feat:添加xgboosttrain代码

zhaohaipeng hace 8 meses
padre
commit
9e67052c9c

+ 5 - 1
pom.xml

@@ -176,7 +176,11 @@
             <artifactId>lombok</artifactId>
             <version>1.18.24</version>
         </dependency>
-
+        <dependency>
+            <groupId>ml.dmlc</groupId>
+            <artifactId>xgboost4j-spark_2.12</artifactId>
+            <version>1.7.4</version>
+        </dependency>
     </dependencies>
 
     <build>

+ 2 - 1
src/main/resources/20240718_ad_feature_name.txt

@@ -685,4 +685,5 @@ vid_rank_ecpm_1d
 vid_rank_ecpm_3d
 vid_rank_ecpm_7d
 vid_rank_ecpm_14d
-ctitle_vtitle_similarity
+ctitle_vtitle_similarity
+weight

+ 2 - 1
src/main/resources/20240718_ad_feature_name_517.txt

@@ -514,4 +514,5 @@ vid_rank_ctcvr_1d
 vid_rank_ctcvr_3d
 vid_rank_ctcvr_7d
 vid_rank_ctcvr_14d
-ctitle_vtitle_similarity
+ctitle_vtitle_similarity
+weight

+ 128 - 0
src/main/scala/com/aliyun/odps/spark/ad/xgboost/v20240808/XGBoostTrain.scala

@@ -0,0 +1,128 @@
+package com.aliyun.odps.spark.ad.xgboost.v20240808
+
+import com.aliyun.odps.spark.examples.myUtils.ParamUtils
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
+import org.apache.commons.lang3.StringUtils
+import org.apache.commons.lang3.math.NumberUtils
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+import org.apache.spark.sql.{Row, SparkSession}
+
+import java.net.URL
+import scala.io.Source
+import scala.reflect.ClassTag.Any
+
+object XGBoostTrain {
+  def main(args: Array[String]): Unit = {
+    try {
+
+      val param = ParamUtils.parseArgs(args)
+
+      val spark = SparkSession.builder()
+        .appName("XGBoostTrain")
+        .getOrCreate()
+      val sc = spark.sparkContext
+
+      val loader = getClass.getClassLoader
+
+      val readPath = param.getOrElse("readPath", "")
+      val filterNameSet = param.getOrElse("filterNames", "").split(",").filter(_.nonEmpty).toSet
+      val featureNameFile = param.getOrElse("featureNameFile", "20240718_ad_feature_name.txt")
+
+      val featureNameContent = readFile(loader.getResource(featureNameFile))
+
+      val featureNameList = featureNameContent.split("\n")
+        .map(r => r.replace(" ", "").replaceAll("\n", ""))
+        .filter(r => r.nonEmpty)
+        .filter(r => !containsAny(filterNameSet, r))
+        .toList
+
+      val rowRDD = sc.textFile(readPath).map(r => {
+        val line = r.split("\t")
+
+        val label = NumberUtils.toInt(line(0))
+
+        val map = line.drop(1).map { entry =>
+          val Array(key, value) = StringUtils.split(entry, ':')
+          key -> NumberUtils.toDouble(value, 0.0)
+        }.toMap
+
+        val v = Array.ofDim[Any](featureNameList.length + 1)
+        v(0) = label
+
+        for (index <- featureNameList.indices) {
+          v(index + 1) = map.getOrElse(featureNameList(1), 0.0)
+        }
+
+        Row.fromSeq(v)
+      })
+      println(s"rowRDD count ${rowRDD.count()}")
+
+      val fields = Seq(
+        StructField("label", DataTypes.IntegerType, true)
+      ) ++ featureNameFile.map(f => StructField(f.toString, DataTypes.DoubleType, true))
+
+      val dataset = spark.createDataFrame(rowRDD, StructType(fields))
+
+      val assembler = new VectorAssembler()
+        .setInputCols(featureNameList.toArray)
+        .setOutputCol("features")
+
+      val assembledData = assembler.transform(dataset)
+      assembledData.show()
+
+      // 划分训练集和测试集
+      val Array(trainData, testData) = assembledData.randomSplit(Array(0.7, 0.3))
+      trainData.show()
+      testData.show()
+
+      // 创建 XGBoostClassifier 对象
+      val xgbClassifier = new XGBoostClassifier()
+        .setEta(0.1f)
+        .setMissing(0.0f)
+        .setFeaturesCol("features")
+        .setLabelCol("label")
+        .setMaxDepth(5)
+        .setObjective("binary:logistic")
+        .setNthread(1)
+        .setNumRound(5)
+        .setNumWorkers(1)
+
+      // 训练模型
+      val model = xgbClassifier.fit(trainData)
+
+      // 显示预测结果
+      val predictions = model.transform(testData)
+      predictions.show(100)
+    }
+    catch {
+      case e: Throwable => e.printStackTrace()
+    }
+  }
+
+  private def readFile(filePath: URL): String = {
+    var source: Option[Source] = None
+    try {
+      source = Some(Source.fromURL(filePath))
+      source.get.getLines().mkString("\n")
+    }
+    catch {
+      case e: Exception => {
+        println("文件读取异常: " + e.toString)
+        ""
+      }
+    }
+    finally {
+      source.foreach(_.close())
+    }
+  }
+
+  private def containsAny(list: Iterable[String], s: String): Boolean = {
+    for (item <- list) {
+      if (s.contains(item)) {
+        return true
+      }
+    }
+    false
+  }
+}