@@ -2,13 +2,16 @@ package com.aliyun.odps.spark.ad.xgboost.v20240808
 
 import com.aliyun.odps.spark.examples.myUtils.ParamUtils
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
+import org.apache.commons.lang3.StringUtils
 import org.apache.commons.lang3.math.NumberUtils
-import org.apache.spark.SparkConf
 import org.apache.spark.ml.feature.VectorAssembler
-import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.{DataTypes, StructField}
+import org.apache.spark.sql.{Dataset, Row, SparkSession}
 
 import java.net.URL
+import java.time.LocalDateTime
+import java.time.format.DateTimeFormatter
 import scala.io.Source
 
 object XGBoostTrain {
@@ -17,91 +20,64 @@ object XGBoostTrain {
 
       val param = ParamUtils.parseArgs(args)
 
-      val conf = new SparkConf()
-        .set("spark.yarn.appMasterEnv.PYSPARK_PYTHON", "/usr/bin/python2.7")
-        .set("spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON", "/usr/bin/python2.7")
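+      // Timestamp the app name so individual runs are easy to tell apart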
+      val dt = LocalDateTime.now.format(DateTimeFormatter.ofPattern("yyyyMMddHHmmss"))
       val spark = SparkSession.builder()
-        .config(conf)
-        .appName("XGBoostTrain")
+        .appName("XGBoostTrain:" + dt)
         .getOrCreate()
       val sc = spark.sparkContext
 
       val loader = getClass.getClassLoader
 
-      val readPath = param.getOrElse("readPath", "")
+      val readPath = param.getOrElse("trainReadPath", "")
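+      // Input path for the prediction set; read here but not yet consumed in this revision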
+      val predictReadPath = param.getOrElse("predictReadPath", "")
       val filterNameSet = param.getOrElse("filterNames", "").split(",").filter(_.nonEmpty).toSet
       val featureNameFile = param.getOrElse("featureNameFile", "20240718_ad_feature_name.txt")
 
       val featureNameContent = readFile(loader.getResource(featureNameFile))
 
-      val featureNameList = featureNameContent.split("\n")
+      val featureNameList: List[String] = featureNameContent.split("\n")
         .map(r => r.replace(" ", "").replaceAll("\n", ""))
         .filter(r => r.nonEmpty)
         .filter(r => !containsAny(filterNameSet, r))
         .toList
 
-      val rowRDD = sc.textFile(readPath).map(r => {
-        val line = r.split("\t")
-
-        val label = NumberUtils.toInt(line(0))
-
-        val map = line.drop(1).map { entry =>
-          val Array(key, value) = entry.split(":")
-          key -> NumberUtils.toDouble(value, 0.0)
-        }.toMap
-
-        val v = Array.ofDim[Any](featureNameList.length + 1)
-        v(0) = label
-
-        for (index <- featureNameList.indices) {
-          v(index + 1) = map.getOrElse(featureNameList(index), 0.0)
-        }
-
-        Row.fromSeq(v)
-      })
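+      // Each input line: label \t name:value \t name:value ... (parsed by dataMap below)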
+      val rowRDD = dataMap(sc.textFile(readPath), featureNameList)
 
       println(s"rowRDD count ${rowRDD.count()}")
 
-      val fields = Seq(
-        StructField("label", DataTypes.IntegerType, true)
-      ) ++ featureNameList.map(f => StructField(f, DataTypes.DoubleType, true))
-
-      val dataset = spark.createDataFrame(rowRDD, StructType(fields))
+      val fields: Array[StructField] = Array(
+        DataTypes.createStructField("label", DataTypes.IntegerType, true)
+      ) ++ featureNameList.map(f => DataTypes.createStructField(f, DataTypes.DoubleType, true))
 
-      val assembler = new VectorAssembler()
-        .setInputCols(featureNameList.toArray)
-        .setOutputCol("features")
+      val trainDataSet: Dataset[Row] = spark.createDataFrame(rowRDD, DataTypes.createStructType(fields))
 
-      val assembledData = assembler.transform(dataset)
-      assembledData.show()
+      val vectorAssembler = new VectorAssembler().setInputCols(featureNameList.toArray).setOutputCol("features")
 
-      // Split into training and test sets
-      val Array(trainData, testData) = assembledData.randomSplit(Array(0.7, 0.3))
-      trainData.show()
-      testData.show()
+      val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label")
+      xgbInput.show()
 
       // Create the XGBoostClassifier
       val xgbClassifier = new XGBoostClassifier()
         .setEta(0.01f)
+        .setMissing(0.0f)
+        .setMaxDepth(5)
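+        // more boosting rounds (100 -> 1000) to compensate for the small eta (0.01)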
+        .setNumRound(1000)
         .setSubsample(0.8)
         .setColsampleBytree(0.8)
         .setScalePosWeight(1)
-        .setSeed(2024)
-        .setMissing(0.0f)
+        .setObjective("binary:logistic")
+        .setEvalMetric("auc")
         .setFeaturesCol("features")
         .setLabelCol("label")
-        .setMaxDepth(5)
-        .setObjective("binary:logistic")
         .setNthread(1)
-        .setNumWorkers(1)
-        .setNumRound(100)
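+        // NOTE: numWorkers should match the executors actually available to the job;
+        // if fewer can be scheduled, XGBoost training will stall waiting for workers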
+        .setNumWorkers(22)
 
       // Train the model
-      val model = xgbClassifier.fit(trainData)
+      val model = xgbClassifier.fit(xgbInput)
+
 
-      // Show the predictions
-      val predictions = model.transform(testData)
-      predictions.show(100)
     }
     catch {
       case e: Throwable => e.printStackTrace()
@@ -112,17 +88,15 @@ object XGBoostTrain {
     var source: Option[Source] = None
     try {
       source = Some(Source.fromURL(filePath))
-      source.get.getLines().mkString("\n")
+      return source.get.getLines().mkString("\n")
     }
     catch {
-      case e: Exception => {
-        println("File read error: " + e.toString)
-        ""
-      }
+      case e: Exception => println("File read error: " + e.toString)
     }
     finally {
       source.foreach(_.close())
     }
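+    // Fallback: return an empty string when the resource could not be read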
+ ""
|
|
}
|
|
}
|
|
|
|
|
|
private def containsAny(list: Iterable[String], s: String): Boolean = {
|
|
private def containsAny(list: Iterable[String], s: String): Boolean = {
|
|
@@ -133,4 +107,25 @@ object XGBoostTrain {
     }
     false
   }
+
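+  // Turns "label \t name:value ..." lines into Rows shaped as (label, feature1, ..., featureN);
+  // features missing from a line default to 0.0 (treated as missing via setMissing(0.0f))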
+  private def dataMap(data: RDD[String], featureNameList: List[String]): RDD[Row] = {
+    data.map(r => {
+      val line: Array[String] = StringUtils.split(r, "\t")
+      val label: Int = NumberUtils.toInt(line(0))
+
+      val map: Map[String, Double] = line.drop(1).map { entry =>
+        val Array(key, value) = entry.split(":")
+        key -> NumberUtils.toDouble(value, 0.0)
+      }.toMap
+
+      val v: Array[Any] = Array.ofDim[Any](featureNameList.length + 1)
+      v(0) = label
+
+      for (index <- featureNameList.indices) {
+        v(index + 1) = map.getOrElse(featureNameList(index), 0.0)
+      }
+
+      Row.fromSeq(v)
+    })
+  }
 }