|
@@ -3,6 +3,7 @@ package com.aliyun.odps.spark.ad.xgboost.v20240808
|
|
|
import com.aliyun.odps.spark.examples.myUtils.ParamUtils
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
|
|
|
import org.apache.commons.lang3.math.NumberUtils
|
|
|
+import org.apache.spark.SparkConf
|
|
|
import org.apache.spark.ml.feature.VectorAssembler
|
|
|
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
|
|
|
import org.apache.spark.sql.{Row, SparkSession}
|
|
@@ -16,7 +17,12 @@ object XGBoostTrain {
|
|
|
|
|
|
val param = ParamUtils.parseArgs(args)
|
|
|
|
|
|
+ val conf = new SparkConf()
|
|
|
+ .set("spark.yarn.appMasterEnv.PYSPARK_PYTHON", "/usr/bin/python2.7")
|
|
|
+ .set("spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON", "/usr/bin/python2.7")
|
|
|
+
|
|
|
val spark = SparkSession.builder()
|
|
|
+ .config(conf)
|
|
|
.appName("XGBoostTrain")
|
|
|
.getOrCreate()
|
|
|
val sc = spark.sparkContext
|
|
@@ -87,8 +93,8 @@ object XGBoostTrain {
|
|
|
.setMaxDepth(5)
|
|
|
.setObjective("binary:logistic")
|
|
|
.setNthread(1)
|
|
|
- .setNumRound(100)
|
|
|
.setNumWorkers(1)
|
|
|
+ .setNumRound(100)
|
|
|
|
|
|
|
|
|
val model = xgbClassifier.fit(trainData)
|