|
@@ -1,5 +1,6 @@
|
|
|
package com.tzld.piaoquan.recommend.model.produce.xgboost;
|
|
|
|
|
|
+import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
|
|
@@ -59,27 +60,6 @@ public class XGBoostTrainLocalTest {
|
|
|
file = "/Users/dingyunpeng/Desktop/part-00099.gz";
|
|
|
JavaRDD<String> rdd = jsc.textFile(file);
|
|
|
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
JavaRDD<Row> rowRDD = rdd.map(s -> {
|
|
|
String[] line = StringUtils.split(s, '\t');
|
|
|
int label = NumberUtils.toInt(line[0]);
|
|
@@ -142,10 +122,21 @@ public class XGBoostTrainLocalTest {
|
|
|
|
|
|
XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
|
|
|
|
|
|
+
|
|
|
+ String path = "/Users/dingyunpeng/Desktop/model";
|
|
|
+ model.save(path);
|
|
|
+
|
|
|
+ OSSService ossService = new OSSService();
|
|
|
+ String bucketName = "";
|
|
|
+ String ossPath = "";
|
|
|
+ ossService.upload(bucketName, path, ossPath);
|
|
|
+
|
|
|
|
|
|
- Dataset<Row> predictions = model.transform(assembledData);
|
|
|
+ XGBoostClassificationModel model2 = XGBoostClassificationModel.load(path);
|
|
|
+ Dataset<Row> predictions = model2.transform(assembledData);
|
|
|
predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
|
|
|
|
|
|
+ spark.close();
|
|
|
|
|
|
} catch (Throwable e) {
|
|
|
log.error("", e);
|