فهرست منبع

Update train_01_xgb_ad_20250104: support testOnly mode

StrayWarrior 3 ماه پیش
والد
کامیت
2d6703ab6a

+ 35 - 27
recommend-model-produce/src/main/scala/com/tzld/piaoquan/recommend/model/train_01_xgb_ad_20250104.scala

@@ -37,6 +37,7 @@ object train_01_xgb_ad_20250104 {
     val func_metric = param.getOrElse("func_metric", "auc")
     val repartition = param.getOrElse("repartition", "20").toInt
     val negSampleRate = param.getOrElse("negSampleRate", "1").toDouble
+    val testOnly = param.getOrElse("testOnly", "false").toBoolean
 
     val loader = getClass.getClassLoader
     val resourceUrl = loader.getResource(featureFile)
@@ -63,34 +64,41 @@ object train_01_xgb_ad_20250104 {
     )
     val schema = DataTypes.createStructType(fields)
 
-    val trainData = createData4Ad(
-      sc.textFile(trainPath),
-      features,
-      negSampleRate
-    )
-    // println("path %s, train data size:%d".format(trainPath, trainData.count()))
-    val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
+    var model: XGBoostClassificationModel = null
     val vectorAssembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
-    val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label")
-
-    val xgbClassifier = new XGBoostClassifier()
-      .setEta(eta)
-      .setGamma(gamma)
-      .setMissing(0.0f)
-      .setMaxDepth(max_depth)
-      .setNumRound(num_round)
-      .setSubsample(0.8)
-      .setColsampleBytree(0.8)
-      .setScalePosWeight(1)
-      .setObjective(func_object)
-      .setEvalMetric(func_metric)
-      .setFeaturesCol("features")
-      .setLabelCol("label")
-      .setNthread(1)
-      .setNumWorkers(num_worker)
-      .setSeed(2024)
-      .setMinChildWeight(1)
-    val model = xgbClassifier.fit(xgbInput)
+
+    if (!testOnly) {
+      val trainData = createData4Ad(
+        sc.textFile(trainPath),
+        features,
+        negSampleRate
+      )
+      // println("path %s, train data size:%d".format(trainPath, trainData.count()))
+      val trainDataSet: Dataset[Row] = spark.createDataFrame(trainData, schema)
+      val xgbInput = vectorAssembler.transform(trainDataSet).select("features", "label")
+
+      val xgbClassifier = new XGBoostClassifier()
+        .setEta(eta)
+        .setGamma(gamma)
+        .setMissing(0.0f)
+        .setMaxDepth(max_depth)
+        .setNumRound(num_round)
+        .setSubsample(0.8)
+        .setColsampleBytree(0.8)
+        .setScalePosWeight(1)
+        .setObjective(func_object)
+        .setEvalMetric(func_metric)
+        .setFeaturesCol("features")
+        .setLabelCol("label")
+        .setNthread(1)
+        .setNumWorkers(num_worker)
+        .setSeed(2024)
+        .setMinChildWeight(1)
+      model = xgbClassifier.fit(xgbInput)
+    } else {
+      model = XGBoostClassificationModel.load()
+    }
+
 
     val testData = createData4Ad(
       sc.textFile(testPath),