|
@@ -74,7 +74,8 @@ public class XGBoostService {
|
|
public void predict(String[] args) {
|
|
public void predict(String[] args) {
|
|
try {
|
|
try {
|
|
|
|
|
|
- Dataset<Row> assembledData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz");
|
|
|
|
|
|
+ Dataset<Row> assembledData =
|
|
|
|
+ dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz").select("features");
|
|
log.info("测试样本 show");
|
|
log.info("测试样本 show");
|
|
assembledData.show();
|
|
assembledData.show();
|
|
|
|
|
|
@@ -91,6 +92,8 @@ public class XGBoostService {
|
|
// 显示预测结果
|
|
// 显示预测结果
|
|
XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
|
|
XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
|
|
model2.setMissing(0.0f);
|
|
model2.setMissing(0.0f);
|
|
|
|
+ model2.setFeaturesCol("features");
|
|
|
|
+
|
|
Dataset<Row> predictions = model2.transform(assembledData);
|
|
Dataset<Row> predictions = model2.transform(assembledData);
|
|
predictions.select("label", "prediction", "rawPrediction", "probability", "features").show(500);
|
|
predictions.select("label", "prediction", "rawPrediction", "probability", "features").show(500);
|
|
|
|
|
|
@@ -147,7 +150,7 @@ public class XGBoostService {
|
|
|
|
|
|
JavaRDD<Row> rowRDD = rdd.map(s -> {
|
|
JavaRDD<Row> rowRDD = rdd.map(s -> {
|
|
String[] line = StringUtils.split(s, '\t');
|
|
String[] line = StringUtils.split(s, '\t');
|
|
- int label = NumberUtils.toInt(line[0]);
|
|
|
|
|
|
+ double label = NumberUtils.toDouble(line[0]);
|
|
// 选特征
|
|
// 选特征
|
|
Map<String, Double> map = new HashMap<>();
|
|
Map<String, Double> map = new HashMap<>();
|
|
for (int i = 1; i < line.length; i++) {
|
|
for (int i = 1; i < line.length; i++) {
|
|
@@ -167,7 +170,7 @@ public class XGBoostService {
|
|
log.info("rowRDD count {}", rowRDD.count());
|
|
log.info("rowRDD count {}", rowRDD.count());
|
|
// 将 JavaRDD<Row> 转换为 Dataset<Row>
|
|
// 将 JavaRDD<Row> 转换为 Dataset<Row>
|
|
List<StructField> fields = new ArrayList<>();
|
|
List<StructField> fields = new ArrayList<>();
|
|
- fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
|
|
|
|
|
|
+ fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
|
|
for (String f : features) {
|
|
for (String f : features) {
|
|
fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
|
|
fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
|
|
}
|
|
}
|
|
@@ -178,7 +181,7 @@ public class XGBoostService {
|
|
.setInputCols(features)
|
|
.setInputCols(features)
|
|
.setOutputCol("features");
|
|
.setOutputCol("features");
|
|
|
|
|
|
- Dataset<Row> assembledData = assembler.transform(dataset);
|
|
|
|
|
|
+ Dataset<Row> assembledData = assembler.transform(dataset).select("features", "label");
|
|
assembledData.show();
|
|
assembledData.show();
|
|
return assembledData;
|
|
return assembledData;
|
|
}
|
|
}
|