@@ -0,0 +1,131 @@
+package com.aliyun.odps.spark.ad.xgboost.v20240808
+
+import com.aliyun.odps.spark.examples.myUtils.ParamUtils
+import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
+import org.apache.commons.lang3.StringUtils
+import org.apache.commons.lang3.math.NumberUtils
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
+import org.apache.spark.sql.{Row, SparkSession}
+
+import java.net.URL
+import scala.io.Source
+
+object XGBoostTrain {
+  def main(args: Array[String]): Unit = {
+    try {
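+      // Parse job arguments into a key -> value map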
+      val param = ParamUtils.parseArgs(args)
+
+      val spark = SparkSession.builder()
+        .appName("XGBoostTrain")
+        .getOrCreate()
+      val sc = spark.sparkContext
+
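+      // The feature-name file is looked up as a classpath resource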
+      val loader = getClass.getClassLoader
+
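+      // readPath: input data path; filterNames: comma-separated substrings used to
+      // drop matching feature names; featureNameFile: resource listing feature names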
+      val readPath = param.getOrElse("readPath", "")
+      val filterNameSet = param.getOrElse("filterNames", "").split(",").filter(_.nonEmpty).toSet
+      val featureNameFile = param.getOrElse("featureNameFile", "20240718_ad_feature_name.txt")
+
+      val featureNameContent = readFile(loader.getResource(featureNameFile))
+
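+      // One feature name per line; strip whitespace and drop filtered names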
+      val featureNameList = featureNameContent.split("\n")
+        .map(_.replaceAll("\\s", ""))
+        .filter(_.nonEmpty)
+        .filterNot(name => containsAny(filterNameSet, name))
+        .toList
+
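+      // Each input line is tab-separated: <label> <name>:<value> <name>:<value> ...
+      // Features absent from a line default to 0.0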
+      val rowRDD = sc.textFile(readPath).map { r =>
+        val line = r.split("\t")
+
+        val label = NumberUtils.toInt(line(0))
+
+        val map = line.drop(1).map { entry =>
+          val Array(key, value) = StringUtils.split(entry, ':')
+          key -> NumberUtils.toDouble(value, 0.0)
+        }.toMap
+
+        val v = Array.ofDim[Any](featureNameList.length + 1)
+        v(0) = label
+
+        // Fill the vector in featureNameList order, defaulting missing features to 0.0
+        for (index <- featureNameList.indices) {
+          v(index + 1) = map.getOrElse(featureNameList(index), 0.0)
+        }
+
+        Row.fromSeq(v)
+      }
+      println(s"rowRDD count ${rowRDD.count()}")
+
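+      // Schema: an integer label column followed by one double column per feature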
+      val fields = Seq(
+        StructField("label", DataTypes.IntegerType, true)
+      ) ++ featureNameList.map(name => StructField(name, DataTypes.DoubleType, true))
+
+      val dataset = spark.createDataFrame(rowRDD, StructType(fields))
+
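+      // Assemble the feature columns into the single vector column XGBoost consumes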
+      val assembler = new VectorAssembler()
+        .setInputCols(featureNameList.toArray)
+        .setOutputCol("features")
+
+      val assembledData = assembler.transform(dataset)
+      assembledData.show()
+
+      // Split into training and test sets (70/30)
+      val Array(trainData, testData) = assembledData.randomSplit(Array(0.7, 0.3))
+      trainData.show()
+      testData.show()
+
+      // Configure the XGBoost classifier for binary logistic regression
+      val xgbClassifier = new XGBoostClassifier()
+        .setEta(0.1f)
+        .setMissing(0.0f)
+        .setFeaturesCol("features")
+        .setLabelCol("label")
+        .setMaxDepth(5)
+        .setObjective("binary:logistic")
+        .setNthread(1)
+        .setNumRound(5)
+        .setNumWorkers(1)
+
+      // Train the model
+      val model = xgbClassifier.fit(trainData)
+
+      // Score the held-out test set and show a sample of predictions
+      val predictions = model.transform(testData)
+      predictions.show(100)
+    }
+    catch {
+      case e: Throwable => e.printStackTrace()
+    }
+  }
+
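+  // Read a resource URL fully into a string; returns "" if reading fails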
+  private def readFile(filePath: URL): String = {
+    var source: Option[Source] = None
+    try {
+      source = Some(Source.fromURL(filePath))
+      source.get.getLines().mkString("\n")
+    }
+    catch {
+      case e: Exception =>
+        println("Failed to read file: " + e.toString)
+        ""
+    }
+    finally {
+      source.foreach(_.close())
+    }
+  }
+
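+  // True if s contains any of the given substrings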
+  private def containsAny(list: Iterable[String], s: String): Boolean = {
+    list.exists(item => s.contains(item))
+  }
+}