|
@@ -37,6 +37,8 @@ object recsys_01_xgb_train {
|
|
|
val func_metric = param.getOrElse("func_metric", "auc")
|
|
|
val repartition = param.getOrElse("repartition", "20").toInt
|
|
|
val subsample = param.getOrElse("subsample", "0.95").toDouble
|
|
|
+ val modelPath = param.getOrElse("modelPath", "/dw/recommend/model/45_recommend_model/")
|
|
|
+ val modelFile = param.getOrElse("modelFile", "model.tar.gz")
|
|
|
|
|
|
val loader = getClass.getClassLoader
|
|
|
val resourceUrl = loader.getResource(featureFile)
|
|
@@ -97,6 +99,9 @@ object recsys_01_xgb_train {
|
|
|
.setMinChildWeight(1)
|
|
|
val model = xgbClassifier.fit(xgbInput)
|
|
|
|
|
|
+ if (modelPath.nonEmpty && modelFile.nonEmpty) {
|
|
|
+ model.write.overwrite().save(modelPath)
|
|
|
+ }
|
|
|
|
|
|
val testData = createData4Ad(
|
|
|
sc.textFile(testPath),
|