瀏覽代碼

feat:修改环境变量

zhaohaipeng 9 月之前
父節點
當前提交
9fc01d9baa
共有 1 個文件被更改,包括 7 次插入和 1 次删除
  1. 7 1
      src/main/scala/com/aliyun/odps/spark/ad/xgboost/v20240808/XGBoostTrain.scala

+ 7 - 1
src/main/scala/com/aliyun/odps/spark/ad/xgboost/v20240808/XGBoostTrain.scala

@@ -3,6 +3,7 @@ package com.aliyun.odps.spark.ad.xgboost.v20240808
 import com.aliyun.odps.spark.examples.myUtils.ParamUtils
 import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
 import org.apache.commons.lang3.math.NumberUtils
+import org.apache.spark.SparkConf
 import org.apache.spark.ml.feature.VectorAssembler
 import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
 import org.apache.spark.sql.{Row, SparkSession}
@@ -16,7 +17,12 @@ object XGBoostTrain {
 
       val param = ParamUtils.parseArgs(args)
 
+      val conf = new SparkConf()
+        .set("spark.yarn.appMasterEnv.PYSPARK_PYTHON", "/usr/bin/python2.7")
+        .set("spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON", "/usr/bin/python2.7")
+
       val spark = SparkSession.builder()
+        .config(conf)
         .appName("XGBoostTrain")
         .getOrCreate()
       val sc = spark.sparkContext
@@ -87,8 +93,8 @@ object XGBoostTrain {
         .setMaxDepth(5)
         .setObjective("binary:logistic")
         .setNthread(1)
-        .setNumRound(100)
         .setNumWorkers(1)
+        .setNumRound(100)
 
       // 训练模型
       val model = xgbClassifier.fit(trainData)