|
@@ -33,10 +33,12 @@ public class XGBoostService {
|
|
|
|
|
|
public void train(String[] args) {
|
|
|
try {
|
|
|
- Dataset<Row> assembledData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00099.gz");
|
|
|
- log.info("训练样本 show");
|
|
|
- assembledData.show();
|
|
|
- // 创建 XGBoostClassifier 对象
|
|
|
+ CMDService cmd = new CMDService();
|
|
|
+ Map<String, String> argMap = cmd.parse(args);
|
|
|
+ String path = argMap.get("path");
|
|
|
+ // 训练
|
|
|
+ Dataset<Row> trainData = dataset(path);
|
|
|
+ trainData.show();
|
|
|
XGBoostClassifier xgbClassifier = new XGBoostClassifier()
|
|
|
.setEta(0.01f)
|
|
|
.setSubsample(0.8)
|
|
@@ -51,20 +53,17 @@ public class XGBoostService {
|
|
|
.setNthread(1)
|
|
|
.setNumRound(100)
|
|
|
.setNumWorkers(1);
|
|
|
-
|
|
|
-
|
|
|
- // 训练模型
|
|
|
- XGBoostClassificationModel model = xgbClassifier.fit(assembledData);
|
|
|
+ XGBoostClassificationModel model = xgbClassifier.fit(trainData);
|
|
|
|
|
|
// 保存模型
|
|
|
- String path = "/root/recommend-model/modeltrain";
|
|
|
- model.write().overwrite().save("file://" + path);
|
|
|
- String outputPath = "/root/recommend-model/model.tar.gz";
|
|
|
- CompressUtil.compressDirectoryToGzip(path, outputPath);
|
|
|
+ String modelPath = "/root/recommend-model/modeltrain";
|
|
|
+ model.write().overwrite().save("file://" + modelPath);
|
|
|
+ String gzPath = "/root/recommend-model/model.tar.gz";
|
|
|
+ CompressUtil.compressDirectoryToGzip(modelPath, gzPath);
|
|
|
String bucketName = "art-test-video";
|
|
|
String ossPath = "test/model.tar.gz";
|
|
|
OSSService ossService = new OSSService();
|
|
|
- ossService.upload(bucketName, outputPath, ossPath);
|
|
|
+ ossService.upload(bucketName, gzPath, ossPath);
|
|
|
|
|
|
} catch (Throwable e) {
|
|
|
log.error("", e);
|
|
@@ -74,32 +73,37 @@ public class XGBoostService {
|
|
|
public void predict(String[] args) {
|
|
|
try {
|
|
|
|
|
|
- Dataset<Row> assembledData = dataset("/dw/recommend/model/33_ad_train_data_v4/20240726/part-00098.gz");
|
|
|
- log.info("测试样本 show");
|
|
|
- assembledData.show();
|
|
|
+ CMDService cmd = new CMDService();
|
|
|
+ Map<String, String> argMap = cmd.parse(args);
|
|
|
+ String path = argMap.get("path");
|
|
|
|
|
|
- // 保存模型
|
|
|
+
|
|
|
+ // 加载模型
|
|
|
String bucketName = "art-test-video";
|
|
|
String objectName = "test/model.tar.gz";
|
|
|
OSSService ossService = new OSSService();
|
|
|
|
|
|
- String destPath = "/root/recommend-model/model2.tar.gz";
|
|
|
- ossService.download(bucketName, destPath, objectName);
|
|
|
- String destDir = "/root/recommend-model/modelpredict";
|
|
|
- CompressUtil.decompressGzFile(destPath, destDir);
|
|
|
+ String gzPath = "/root/recommend-model/model2.tar.gz";
|
|
|
+ ossService.download(bucketName, gzPath, objectName);
|
|
|
+ String modelDir = "/root/recommend-model/modelpredict";
|
|
|
+ CompressUtil.decompressGzFile(gzPath, modelDir);
|
|
|
+
|
|
|
+ XGBoostClassificationModel model = XGBoostClassificationModel.load("file://" + modelDir);
|
|
|
+ model.setMissing(0.0f)
|
|
|
+ .setFeaturesCol("features");
|
|
|
+
|
|
|
|
|
|
- // 显示预测结果
|
|
|
- XGBoostClassificationModel model2 = XGBoostClassificationModel.load("file://" + destDir);
|
|
|
- Dataset<Row> predictions = model2.transform(assembledData);
|
|
|
- predictions.select("label", "prediction", "rawPrediction", "probability", "features").show(500);
|
|
|
+ // 预测
|
|
|
+ Dataset<Row> predictData = dataset(path);
|
|
|
+ predictData.show();
|
|
|
+ Dataset<Row> predictions = model.transform(predictData);
|
|
|
+ predictions.show();
|
|
|
|
|
|
// 计算AUC
|
|
|
- Dataset<Row> selected = predictions.select("label", "rawPrediction");
|
|
|
BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
|
|
|
.setLabelCol("label")
|
|
|
- .setRawPredictionCol("rawPrediction")
|
|
|
- .setMetricName("areaUnderROC");
|
|
|
- double auc = evaluator.evaluate(selected);
|
|
|
+ .setRawPredictionCol("rawPrediction");
|
|
|
+ double auc = evaluator.evaluate(predictions);
|
|
|
log.info("AUC: {}", auc);
|
|
|
|
|
|
} catch (Throwable e) {
|
|
@@ -108,7 +112,8 @@ public class XGBoostService {
|
|
|
}
|
|
|
|
|
|
private static Dataset<Row> dataset(String path) {
|
|
|
- String[] features = {"cpa",
|
|
|
+ String[] features = {
|
|
|
+ "cpa",
|
|
|
"b2_1h_ctr",
|
|
|
"b2_1h_ctcvr",
|
|
|
"b2_1h_cvr",
|
|
@@ -146,7 +151,7 @@ public class XGBoostService {
|
|
|
|
|
|
JavaRDD<Row> rowRDD = rdd.map(s -> {
|
|
|
String[] line = StringUtils.split(s, '\t');
|
|
|
- int label = NumberUtils.toInt(line[0]);
|
|
|
+ double label = NumberUtils.toDouble(line[0]);
|
|
|
// 选特征
|
|
|
Map<String, Double> map = new HashMap<>();
|
|
|
for (int i = 1; i < line.length; i++) {
|
|
@@ -166,7 +171,7 @@ public class XGBoostService {
|
|
|
log.info("rowRDD count {}", rowRDD.count());
|
|
|
// 将 JavaRDD<Row> 转换为 Dataset<Row>
|
|
|
List<StructField> fields = new ArrayList<>();
|
|
|
- fields.add(DataTypes.createStructField("label", DataTypes.IntegerType, true));
|
|
|
+ fields.add(DataTypes.createStructField("label", DataTypes.DoubleType, true));
|
|
|
for (String f : features) {
|
|
|
fields.add(DataTypes.createStructField(f, DataTypes.DoubleType, true));
|
|
|
}
|
|
@@ -177,7 +182,7 @@ public class XGBoostService {
|
|
|
.setInputCols(features)
|
|
|
.setOutputCol("features");
|
|
|
|
|
|
- Dataset<Row> assembledData = assembler.transform(dataset);
|
|
|
+ Dataset<Row> assembledData = assembler.transform(dataset).select("features", "label");
|
|
|
assembledData.show();
|
|
|
return assembledData;
|
|
|
}
|