|
@@ -37,6 +37,7 @@ object train_01_xgb_ad_20250104 {
|
|
|
val func_metric = param.getOrElse("func_metric", "auc")
|
|
|
val repartition = param.getOrElse("repartition", "20").toInt
|
|
|
val negSampleRate = param.getOrElse("negSampleRate", "1").toDouble
|
|
|
+ val testOnly = param.getOrElse("testOnly", "false").toBoolean
|
|
|
|
|
|
val loader = getClass.getClassLoader
|
|
|
val resourceUrl = loader.getResource(featureFile)
|
|
@@ -63,34 +64,41 @@ object train_01_xgb_ad_20250104 {
|
|
|
)
|
|
|
val schema = DataTypes.createStructType(fields)
|
|
|
|
|
|
- val trainData = createData4Ad(
|
|
|
- sc.textFile(trainPath),
|
|
|
- features,
|
|
|
- negSampleRate
|
|
|
- )
|
|
|
- // println("path %s, train data size:%d".format(trainPath, trainData.count()))
|
|
|
- val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
|
|
|
+ var model: XGBoostClassificationModel = null
|
|
|
val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
|
|
|
- val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label")
|
|
|
-
|
|
|
- val xgbClassifier = new XGBoostClassifier()
|
|
|
- .setEta(eta)
|
|
|
- .setGamma(gamma)
|
|
|
- .setMissing(0.0f)
|
|
|
- .setMaxDepth(max_depth)
|
|
|
- .setNumRound(num_round)
|
|
|
- .setSubsample(0.8)
|
|
|
- .setColsampleBytree(0.8)
|
|
|
- .setScalePosWeight(1)
|
|
|
- .setObjective(func_object)
|
|
|
- .setEvalMetric(func_metric)
|
|
|
- .setFeaturesCol("features")
|
|
|
- .setLabelCol("label")
|
|
|
- .setNthread(1)
|
|
|
- .setNumWorkers(num_worker)
|
|
|
- .setSeed(2024)
|
|
|
- .setMinChildWeight(1)
|
|
|
- val model = xgbClassifier.fit(xgbInput)
|
|
|
+
|
|
|
+ if (!testOnly) {
|
|
|
+ val trainData = createData4Ad(
|
|
|
+ sc.textFile(trainPath),
|
|
|
+ features,
|
|
|
+ negSampleRate
|
|
|
+ )
|
|
|
+ // println("path %s, train data size:%d".format(trainPath, trainData.count()))
|
|
|
+ val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
|
|
|
+ val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label")
|
|
|
+
|
|
|
+ val xgbClassifier = new XGBoostClassifier()
|
|
|
+ .setEta(eta)
|
|
|
+ .setGamma(gamma)
|
|
|
+ .setMissing(0.0f)
|
|
|
+ .setMaxDepth(max_depth)
|
|
|
+ .setNumRound(num_round)
|
|
|
+ .setSubsample(0.8)
|
|
|
+ .setColsampleBytree(0.8)
|
|
|
+ .setScalePosWeight(1)
|
|
|
+ .setObjective(func_object)
|
|
|
+ .setEvalMetric(func_metric)
|
|
|
+ .setFeaturesCol("features")
|
|
|
+ .setLabelCol("label")
|
|
|
+ .setNthread(1)
|
|
|
+ .setNumWorkers(num_worker)
|
|
|
+ .setSeed(2024)
|
|
|
+ .setMinChildWeight(1)
|
|
|
+ model = xgbClassifier.fit(xgbInput)
|
|
|
+ } else {
|
|
|
+ model = XGBoostClassificationModel.load()
|
|
|
+ }
|
|
|
+
|
|
|
|
|
|
val testData = createData4Ad(
|
|
|
sc.textFile(testPath),
|