|
@@ -42,7 +42,7 @@ object recsys_01_ros_multi_class_xgb_train {
|
|
|
val func_object = param.getOrElse("func_object", "multi:softprob")
|
|
|
val func_metric = param.getOrElse("func_metric", "auc")
|
|
|
val repartition = param.getOrElse("repartition", "20").toInt
|
|
|
- val numClass = param.getOrElse("numClass", "100").toInt
|
|
|
+ val numClass = param.getOrElse("numClass", "8").toInt
|
|
|
val subsample = param.getOrElse("subsample", "0.95").toDouble
|
|
|
val modelPath = param.getOrElse("modelPath", "/dw/recommend/model/45_recommend_model/")
|
|
|
val modelFile = param.getOrElse("modelFile", "model.tar.gz")
|
|
@@ -106,6 +106,7 @@ object recsys_01_ros_multi_class_xgb_train {
|
|
|
.setNumWorkers(num_worker)
|
|
|
.setSeed(2024)
|
|
|
.setMinChildWeight(1)
|
|
|
+ .setNumClass(numClass)
|
|
|
|
|
|
val model = xgbClassifier.fit(xgbInput)
|
|
|
|