|
@@ -1,5 +1,6 @@
|
|
|
package com.tzld.piaoquan.recommend.model.produce.xgboost;
|
|
|
|
|
|
+import com.tzld.piaoquan.recommend.model.produce.service.OSSService;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassificationModel;
|
|
|
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier;
|
|
@@ -59,27 +60,6 @@ public class XGBoostTrainLocalTest {
|
|
|
file = "/Users/dingyunpeng/Desktop/part-00099.gz";
|
|
|
JavaRDD<String> rdd = jsc.textFile(file);
|
|
|
|
|
|
- // 将 RDD[LabeledPoint] 转换为 JavaRDD<Row>
|
|
|
-// JavaRDD<Row> rowRDD = rdd.map(s -> {
|
|
|
-// String[] line = StringUtils.split(s, '\t');
|
|
|
-// int label = NumberUtils.toInt(line[0]);
|
|
|
-// // 选特征
|
|
|
-// Map<String, Double> map = new HashMap<>();
|
|
|
-// for (int i = 1; i < line.length; i++) {
|
|
|
-// String[] fv = StringUtils.split(line[i], ':');
|
|
|
-// map.put(fv[0], NumberUtils.toDouble(fv[1], 0.0));
|
|
|
-// }
|
|
|
-//
|
|
|
-// int[] indices = new int[features.length];
|
|
|
-// double[] values = new double[features.length];
|
|
|
-// for (int i = 0; i < features.length; i++) {
|
|
|
-// indices[i] = i;
|
|
|
-// values[i] = map.getOrDefault(features[i], 0.0);
|
|
|
-// }
|
|
|
-// SparseVector vector = new SparseVector(indices.length, indices, values);
|
|
|
-// return RowFactory.create(label, vector);
|
|
|
-// });
|
|
|
-
|
|
|
JavaRDD<Row> rowRDD = rdd.map(s -> {
|
|
|
String[] line = StringUtils.split(s, '\t');
|
|
|
int label = NumberUtils.toInt(line[0]);
|
|
@@ -142,10 +122,21 @@ public class XGBoostTrainLocalTest {
|
|
|
// 训练模型
|
|
|
XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
|
|
|
|
|
|
+ // 保存模型
|
|
|
+ String path = "/Users/dingyunpeng/Desktop/model";
|
|
|
+ model.save(path);
|
|
|
+
|
|
|
+ OSSService ossService = new OSSService();
|
|
|
+ String bucketName = "";
|
|
|
+ String ossPath = "";
|
|
|
+ ossService.upload(bucketName, path, ossPath);
|
|
|
+
|
|
|
// 显示预测结果
|
|
|
- Dataset<Row> predictions = model.transform(assembledData);
|
|
|
+ XGBoostClassificationModel model2 = XGBoostClassificationModel.load(path);
|
|
|
+ Dataset<Row> predictions = model2.transform(assembledData);
|
|
|
predictions.select("label", "prediction", "features", "rawPrediction", "probability").show(500);
|
|
|
|
|
|
+ spark.close();
|
|
|
|
|
|
} catch (Throwable e) {
|
|
|
log.error("", e);
|